CGAN 原理及实现

  • 一、CGAN 原理
    • 1.1 基本概念
    • 1.2 与传统GAN的区别
    • 1.3 目标函数
    • 1.4 损失函数
    • 1.5 条件信息的融合方式
    • 1.6 与其他GAN变体的对比
    • 1.7 CGAN的应用
    • 1.8 改进与变体
  • 二、CGAN 实现
    • 2.1 导包
    • 2.2 数据加载和处理
    • 2.3 构建生成器
    • 2.4 构建判别器
    • 2.5 训练和保存模型
    • 2.6 绘制训练损失
    • 2.7 图片转GIF
    • 2.8 模型加载和生成

一、CGAN 原理

1.1 基本概念

条件生成对抗网络(Conditional GAN, CGAN)是GAN的一种扩展,它在生成器和判别器中都加入了额外的条件信息 yyy这个条件信息可以是类别标签、文本描述或其他形式的辅助信息

1.2 与传统GAN的区别

  • 传统GAN: G(z)G(z)G(z) → 生成样本,D(x)D(x)D(x) → 判断真实/生成
  • CGAN: G(z∣y)G(z|y)G(zy) → 基于条件 yyy 生成样本,D(x∣y)D(x|y)D(xy) → 基于条件 yyy 判断真实/生成

1.3 目标函数

CGAN的目标函数可以表示为:minGmaxDV(D,G)=𝔼x∼pdata[logD(x∣y)]+𝔼z∼pz(z)[log(1−D(G(z∣y)∣y))]min_G max_D V(D,G) = 𝔼_{x \sim p_{\text{data}}}[log D(x|y)] + 𝔼_{z \sim p_z(z)}[log(1 - D(G(z|y)|y))]minGmaxDV(D,G)=Expdata[logD(xy)]+Ezpz(z)[log(1D(G(zy)y))],其中 yyy 是条件信息。

1.4 损失函数

(1) 判别器(Discriminator)的损失函数

 \space  \space 判别器需要同时判断:

  1. 真实图像是否真实(且匹配其标签)
  2. 生成图像是否虚假(且匹配其标签)

损失函数公式

LD=Ex,y∼pdata[log⁡D(x∣y)]⏟真实样本损失+Ez∼pz,y∼plabels[log⁡(1−D(G(z∣y)∣y)]⏟生成样本损失\mathcal{L}_D = \underbrace{\mathbb{E}_{x,y \sim p_{\text{data}}}[\log D(x|y)]}_{\text{真实样本损失}} + \underbrace{\mathbb{E}_{z \sim p_z, y \sim p_{\text{labels}}}[\log (1 - D(G(z|y)|y)]}_{\text{生成样本损失}}LD=真实样本损失 Ex,ypdata[logD(xy)]+生成样本损失 Ezpz,yplabels[log(1D(G(zy)y)]

(2)生成器(Generator)的损失函数

 \space  生成器的目标是欺骗判别器,使其认为生成的图像是真实的(且匹配条件标签 yyy)。

损失函数公式

LG=Ez∼pz,y∼plabels[log⁡(1−D(G(z∣y)∣y)]⏟原始形式或−Ez,y[log⁡D(G(z∣y)∣y)]⏟改进形式\mathcal{L}_G = \underbrace{\mathbb{E}_{z \sim p_z, y \sim p_{\text{labels}}}[\log (1 - D(G(z|y)|y)]}_{\text{原始形式}} \quad \text{或} \quad \underbrace{-\mathbb{E}_{z,y}[\log D(G(z|y)|y)]}_{\text{改进形式}}LG=原始形式 Ezpz,yplabels[log(1D(G(zy)y)]改进形式 Ez,y[logD(G(zy)y)]

1.5 条件信息的融合方式

在损失计算中,条件标签通过以下方式参与:

  1. 生成器输入:噪声 z 和标签 y 拼接后输入生成器
    gen_input = torch.cat([z, label_embed], dim=1)  # z: [batch, z_dim], label_embed: [batch, embed_dim]
    
  2. 判别器输入:图像和标签拼接后输入判别器
    # 图像x: [batch, C, H, W], 标签扩展为 [batch, embed_dim, H, W]
    label_expanded = label_embed.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W)
    disc_input = torch.cat([x, label_expanded], dim=1)  # 沿通道维度拼接
    

1.6 与其他GAN变体的对比

损失函数特性 标准GAN 条件GAN(CGAN) WGAN-GP
判别器输出 概率值 (0~1) 条件概率值 未限制的分数
生成器目标 欺骗判别器 生成符合标签的图像 最小化Wasserstein距离
梯度稳定性 易崩溃 依赖条件强度 通过梯度惩罚稳定

1.7 CGAN的应用

  1. 图像生成:根据类别标签生成特定类型的图像
  2. 图像到图像转换:如将语义标签图转换为真实图像
  3. 文本到图像生成:根据文本描述生成图像
  4. 数据增强:为特定类别生成额外的训练样本

1.8 改进与变体

  1. AC-GAN:辅助分类器GAN,在判别器中增加分类任务
  2. InfoGAN:学习可解释的潜在表示
  3. StackGAN:分阶段生成高分辨率图像
  4. ProGAN:渐进式生成高分辨率图像

条件GAN通过引入条件信息使得生成过程更加可控,能够生成特定类别的样本,在实际应用中具有广泛的用途


二、CGAN 实现

2.1 导包

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np

import os
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  

# 判断是否存在可用的GPU
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 设置日志
time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) # 生成当前时间格式(例如:2024-03-15_14-30-00)
log_dir = os.path.join("./logs/cgan", time_str) # 设置日志路径,格式如:./logs/cgan/2024-03-15_14-30-00
os.makedirs(log_dir, exist_ok=True) # 自动创建目录
writer = SummaryWriter(log_dir=log_dir) # 初始化 SummaryWriter

os.makedirs("./img/cgan_mnist", exist_ok=True) # 存放生成样本目录
os.makedirs("./model", exist_ok=True) # 模型存放目录

2.2 数据加载和处理

# 加载 MNIST 数据集
def load_data(batch_size=64,img_shape=(1,32,32)):
    transform = transforms.Compose([
        transforms.Resize((img_shape[1],img_shape[2])),
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize(mean=[0.5], std=[0.5])  # 归一化到[-1,1]
    ])
    
    # 下载训练集和测试集
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    
    # 创建 DataLoader
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=2,shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=2,shuffle=False)
    return train_loader, test_loader

2.3 构建生成器

class Generator(nn.Module):
    """生成器"""
    def __init__(self, img_shape=(1,32,32),latent_dim=100,num_classes=10,label_embed_dim=10):
        """
        Args:
            img_shape (int, optional): 
                生成图片大小,默认CHW=1*32*32
            
            latent_dim (int, optional): 
                潜在噪声向量的维度。默认100维,作为生成器的随机输入种子。             
                
            num_classes (int, optional): 
                类别数量。默认10(例如MNIST的0-9数字分类)。
                决定标签嵌入矩阵的行数。
                
            label_embed_dim (int, optional): 
                标签嵌入向量的维度。默认10维。
                将离散标签映射为连续向量的维度,影响条件信息的表达能力。               
        """
        super(Generator, self).__init__()

        # 定义嵌入层 [batch_szie]-> [batch_size,label_embed_dim]=[64,10]
        self.label_embed = nn.Embedding(num_classes, label_embed_dim) # num_classes 个类别, label_embed_dim 维嵌入

        # 定义网络块
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
            return layers

        # 定义模型架构
        self.model = nn.Sequential(
            *block(latent_dim + label_embed_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))), # [batch_size,1024]-> [batch_size,1*32*32]
            nn.Tanh() # 输出归一化到[-1,1] 
        )
        
    def forward(self, z, labels):
        # 嵌入标签 [batch_size]-> [batch_size,label_embed_dim]=[64,10]
        label_embed = self.label_embed(labels)
        # 拼接嵌入标签和噪声 ->[batch_size,latent_dim + label_embed_dim]=[64,100+10]
        gen_input = torch.cat([label_embed, z], dim=1)
        # 生成图片-> [batch_size,C,H,W]=[64,1,32,32]
        img = self.model(gen_input) # -> [batch_size,C*H*W]=[64,1*32*32]
        img = img.view(img.shape[0], *img_shape) # [batch_size,C*H*W]-> [batch_size,C,H,W]=[64,1,32,32]
        return img 

2.4 构建判别器

class Discriminator(nn.Module):
    """判别器"""
    def __init__(self, img_shape=(1,32,32),label_embed_dim=10):
        
        super(Discriminator, self).__init__()

         # 定义嵌入层 [batch_szie]-> [batch_size,label_embed_dim]=[64,10]
        self.label_embed = nn.Embedding(num_classes, label_embed_dim) # num_classes 个类别, label_embed_dim 维嵌入

        # 定义模型结构
        self.model = nn.Sequential(
            nn.Linear(label_embed_dim+ int(np.prod(img_shape)), 512), # [64,10+1*32*32]-> [64,512]
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        # 嵌入标签 [batch_size]-> [batch_size,label_embed_dim]=[64,10]
        label_embed = self.label_embed(labels)
        # 输入图片展平[64,1,32,32]-> [64,1*32*32]
        img=img.view(img.shape[0], -1)
        # 拼接嵌入标签和输入图片 ->[batch_size,label_embed_dim + C*H*W]=[64,10+1*32*32]
        dis_input = torch.cat([label_embed, img], dim=1)
        # 进行判定
        validity = self.model(dis_input)
        return validity # -> [64,1]

2.5 训练和保存模型

1. 定义保存生成样本

def sample_image(G,n_row, batches_done,latent_dim=100,device=device):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # 随机噪声-> [n_row ** 2,latent_dim]=[100,100]
    z=torch.normal(0,1,size=(n_row ** 2,latent_dim),device=device)  #从正态分布中抽样
    # 条件标签->[100]
    labels = torch.arange(n_row, dtype=torch.long, device=device).repeat_interleave(n_row)
    gen_imgs = G(z, labels)
    save_image(gen_imgs.data, "./img/cgan_mnist/%d.png" % batches_done, nrow=n_row, normalize=True)

2. 训练和保存

# 设置超参数
batch_size = 64
epochs = 200
lr= 0.0002
latent_dim=100 # 生成器输入噪声向量的长度(维数)
sample_interval=400 #每400次迭代保存生成样本
img_shape = (1,32,32) # 图片大小
num_classes=10 # 分类数
label_embed_dim=10 # 嵌入维数

# 加载数据
train_loader,_= load_data(batch_size=batch_size,img_shape=img_shape)

# 实例化生成器G、判别器D
G=Generator().to(device)
D=Discriminator().to(device)

# 设置优化器
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr,betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr,betas=(0.5, 0.999))

# 损失函数
loss_fn=nn.BCEWithLogitsLoss()

# 开始训练
dis_costs,gen_costs = [],[] # 记录生成器和判别器每次迭代的开销(损失)
start_time = time.time()  # 计时器
loader_len=len(train_loader) #训练集加载器的长度
for epoch in range(epochs):
    # 进入训练模式
    G.train()
    D.train()
     
    #记录生成器G和判别器D的总损失(1个 epoch 内)
    gen_loss_sum,dis_loss_sum=0.0,0.0
    
    loop = tqdm(train_loader, desc=f"第{epoch+1}轮")
    for i, (real_imgs, real_labels) in enumerate(loop):
        real_imgs=real_imgs.to(device)  # [B,C,H,W]
        real_labels=real_labels.to(device) # [B]

         # 平滑真假标签,2维[B,1]
        valid_labels = torch.empty(real_imgs.shape[0], 1, device=device).uniform_(0.9, 1.0).requires_grad_(False) # 替代1.0
        fake_labels = torch.empty(real_imgs.shape[0], 1, device=device).uniform_(0.0, 0.1).requires_grad_(False) # 替代0.0

        # -----------------
        #  训练生成器
        # -----------------
        
        # 获取噪声样本[batch_size,latent_dim]及对应的条件标签 [batch_size]
        z=torch.normal(0,1,size=(real_imgs.shape[0],latent_dim),device=device)  #从正态分布中抽样
        gen_labels = torch.randint(0, num_classes, (real_imgs.shape[0],), device=device, dtype=torch.long) # 0~9整数之间,随机抽 real_imgs.shape[0]次

        # 计算生成器损失
        gen_imgs=G(z,gen_labels)
        gen_loss=loss_fn(D(gen_imgs,gen_labels),valid_labels)

        # 更新生成器参数
        optimizer_G.zero_grad() #梯度清零
        gen_loss.backward() #反向传播,计算梯度
        optimizer_G.step()  #更新生成器

        # -----------------
        #  训练判别器
        # -----------------

        # 计算判别器损失
        # Step-1:对真实图片损失
        valid_loss=loss_fn(D(real_imgs,real_labels),valid_labels)
        # Step-2:对生成图片损失
        fake_loss=loss_fn(D(gen_imgs.detach(),gen_labels),fake_labels)
        # Step-3:整体损失
        dis_loss=(valid_loss+fake_loss)/2.0

        # 更新判别器参数
        optimizer_D.zero_grad() #梯度清零
        dis_loss.backward() #反向传播,计算梯度
        optimizer_D.step()  #更新判断器  


        # 对生成器和判别器每次迭代的损失进行累加
        gen_loss_sum+=gen_loss
        dis_loss_sum+=dis_loss
		
		gen_costs.append(gen_loss.item())
        dis_costs.append(dis_loss.item())

        # 每 sample_interval 次迭代保存生成样本
        batches_done = epoch * loader_len + i
        if batches_done % sample_interval == 0:
            sample_image(G=G,n_row=10, batches_done=batches_done)
            # 更新进度条
            loop.set_postfix(mean_gen_loss=f"{gen_loss_sum/(loop.n + 1):.8f}",mean_dis_loss=f"{dis_loss_sum/(loop.n + 1):.8f}")
            
            writer.add_scalars(
                main_tag="Train Losses",  
                tag_scalar_dict={
                    "Generator": gen_loss,
                    "Discriminator": dis_loss,
                },
                global_step=batches_done  # X轴坐标
            )
writer.close()
print('总共训练用时: %.2f min' % ((time.time() - start_time)/60))

#仅保存模型的参数(权重和偏置),灵活性高,可以在不同的模型结构之间加载参数
torch.save(G.state_dict(), "./model/CGAN_G.pth") 
torch.save(D.state_dict(), "./model/CGAN_D.pth") 

2.6 绘制训练损失

# 创建画布
plt.figure(figsize=(10, 5))
ax1 = plt.subplot(1, 1, 1)

# 绘制曲线
ax1.plot(range(len(gen_costs)), gen_costs, label='Generator loss', linewidth=2)
ax1.plot(range(len(dis_costs)), dis_costs, label='Discriminator loss', linewidth=2)

ax1.set_xlabel('Iterations', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('CGAN Training Loss', fontsize=14)
ax1.legend(fontsize=10)
ax1.grid(True, linestyle='--', alpha=0.6)

ax2 = ax1.twiny()  # 创建共享Y轴的新X轴
newlabel = list(range(epochs+1))  # Epoch标签 [0,1,2,...]
iter_per_epoch = len(train_loader)  # 每个epoch的iteration次数
newpos = [e*iter_per_epoch for e in newlabel]  # 计算Epoch对应的iteration位置

ax2.set_xticks(newpos[::10])  
ax2.set_xticklabels(newlabel[::10])  

ax2.xaxis.set_ticks_position('bottom')  
ax2.xaxis.set_label_position('bottom')  
ax2.spines['bottom'].set_position(('outward', 45))  # 坐标轴下移45点
ax2.set_xlabel('Epochs')  # 设置轴标签
ax2.set_xlim(ax1.get_xlim())  # 与主X轴范围同步

plt.tight_layout()  
plt.savefig('cgan_loss.png', dpi=300)  
plt.show()  

2.7 图片转GIF

from PIL import Image

def create_gif(img_dir="./img/cgan_mnist", output_file="./img/cgan_mnist/cgan_figure.gif", duration=100):
    images = []
    img_paths = [f for f in os.listdir(img_dir) if f.endswith(".png")]
    
    img_paths_sorted = sorted(
        img_paths,
        key=lambda x: (
            int(x.split('.')[0]),  # (如 400.png 的 400)
        )
    )
    
    for img_file in img_paths_sorted:
        img = Image.open(os.path.join(img_dir, img_file))
        images.append(img)
    
    images[0].save(output_file, save_all=True, append_images=images[1:], 
                  duration=duration, loop=0)
    print(f"GIF已保存至 {output_file}")
create_gif()

2.8 模型加载和生成

#载入训练好的模型
G = Generator() # 定义模型结构
G.load_state_dict(torch.load("./model/CGAN_G.pth",weights_only=True,map_location=device)) # 加载保存的参数
G.to(device) # 将模型移动到设备(GPU 或 CPU)
G.eval() # 将模型设置为评估模式


# 获取噪声样本[10,100]及对应的条件标签 [10]
z=torch.normal(0,1,size=(10,100),device=device)  #从正态分布中抽样
gen_labels = torch.arange(10, dtype=torch.long, device=device) #0~9整数


#生成假样本
gen_imgs=G(z,gen_labels).view(-1,32,32) # 4维->3维
gen_imgs=gen_imgs.detach().cpu().numpy()


# #绘制
plt.figure(figsize=(3, 2)) 
for i in range(10):
    plt.subplot(2, 5, i + 1)  
    plt.xticks([], [])  
    plt.yticks([], [])  
    plt.imshow(gen_imgs[i], cmap='gray')  
    plt.title(f"Figure {i}", fontsize=5)  
plt.tight_layout()  
plt.show()
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐