条件生成对抗网络(Conditional GAN, CGAN)原理及实现(pytorch版)
条件生成对抗网络(Conditional GAN, cGAN)是GAN的一种扩展,它在生成器和判别器中都加入了额外的条件信息 ,使得生成过程更加可控,能够生成特定类别的样本,在实际应用中具有广泛的用途。这个条件信息可以是类别标签、文本描述或其他形式的辅助信息。
CGAN 原理及实现
- 一、CGAN 原理
-
- 1.1 基本概念
- 1.2 与传统GAN的区别
- 1.3 目标函数
- 1.4 损失函数
- 1.5 条件信息的融合方式
- 1.6 与其他GAN变体的对比
- 1.7 CGAN的应用
- 1.8 改进与变体
- 二、CGAN 实现
-
- 2.1 导包
- 2.2 数据加载和处理
- 2.3 构建生成器
- 2.4 构建判别器
- 2.5 训练和保存模型
- 2.6 绘制训练损失
- 2.7 图片转GIF
- 2.8 模型加载和生成
一、CGAN 原理
1.1 基本概念
条件生成对抗网络(Conditional GAN, CGAN)是GAN的一种扩展,它在生成器和判别器中都加入了额外的条件信息 yyy。这个条件信息可以是类别标签、文本描述或其他形式的辅助信息。
1.2 与传统GAN的区别
- 传统GAN: G(z)G(z)G(z) → 生成样本,D(x)D(x)D(x) → 判断真实/生成
CGAN: G(z∣y)G(z|y)G(z∣y) → 基于条件 yyy 生成样本,D(x∣y)D(x|y)D(x∣y) → 基于条件 yyy 判断真实/生成
1.3 目标函数
CGAN的目标函数可以表示为:minGmaxDV(D,G)=𝔼x∼pdata[logD(x∣y)]+𝔼z∼pz(z)[log(1−D(G(z∣y)∣y))]min_G max_D V(D,G) = 𝔼_{x \sim p_{\text{data}}}[log D(x|y)] + 𝔼_{z \sim p_z(z)}[log(1 - D(G(z|y)|y))]minGmaxDV(D,G)=Ex∼pdata[logD(x∣y)]+Ez∼pz(z)[log(1−D(G(z∣y)∣y))],其中 yyy 是条件信息。
1.4 损失函数
(1) 判别器(Discriminator)的损失函数
\space \space 判别器需要同时判断:
真实图像是否真实(且匹配其标签)生成图像是否虚假(且匹配其标签)
损失函数公式:
LD=Ex,y∼pdata[logD(x∣y)]⏟真实样本损失+Ez∼pz,y∼plabels[log(1−D(G(z∣y)∣y)]⏟生成样本损失\mathcal{L}_D = \underbrace{\mathbb{E}_{x,y \sim p_{\text{data}}}[\log D(x|y)]}_{\text{真实样本损失}} + \underbrace{\mathbb{E}_{z \sim p_z, y \sim p_{\text{labels}}}[\log (1 - D(G(z|y)|y)]}_{\text{生成样本损失}}LD=真实样本损失 Ex,y∼pdata[logD(x∣y)]+生成样本损失 Ez∼pz,y∼plabels[log(1−D(G(z∣y)∣y)]
(2)生成器(Generator)的损失函数
\space 生成器的目标是欺骗判别器,使其认为生成的图像是真实的(且匹配条件标签 yyy)。
损失函数公式:
LG=Ez∼pz,y∼plabels[log(1−D(G(z∣y)∣y)]⏟原始形式或−Ez,y[logD(G(z∣y)∣y)]⏟改进形式\mathcal{L}_G = \underbrace{\mathbb{E}_{z \sim p_z, y \sim p_{\text{labels}}}[\log (1 - D(G(z|y)|y)]}_{\text{原始形式}} \quad \text{或} \quad \underbrace{-\mathbb{E}_{z,y}[\log D(G(z|y)|y)]}_{\text{改进形式}}LG=原始形式 Ez∼pz,y∼plabels[log(1−D(G(z∣y)∣y)]或改进形式 −Ez,y[logD(G(z∣y)∣y)]
1.5 条件信息的融合方式
在损失计算中,条件标签通过以下方式参与:
- 生成器输入:噪声
z和标签y拼接后输入生成器gen_input = torch.cat([z, label_embed], dim=1) # z: [batch, z_dim], label_embed: [batch, embed_dim] - 判别器输入:图像和标签拼接后输入判别器
# 图像x: [batch, C, H, W], 标签扩展为 [batch, embed_dim, H, W] label_expanded = label_embed.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W) disc_input = torch.cat([x, label_expanded], dim=1) # 沿通道维度拼接
1.6 与其他GAN变体的对比
| 损失函数特性 | 标准GAN | 条件GAN(CGAN) | WGAN-GP |
|---|---|---|---|
| 判别器输出 | 概率值 (0~1) | 条件概率值 | 未限制的分数 |
| 生成器目标 | 欺骗判别器 | 生成符合标签的图像 | 最小化Wasserstein距离 |
| 梯度稳定性 | 易崩溃 | 依赖条件强度 | 通过梯度惩罚稳定 |
1.7 CGAN的应用
- 图像生成:根据类别标签生成特定类型的图像
- 图像到图像转换:如将语义标签图转换为真实图像
- 文本到图像生成:根据文本描述生成图像
- 数据增强:为特定类别生成额外的训练样本
1.8 改进与变体
- AC-GAN:辅助分类器GAN,在判别器中增加分类任务
- InfoGAN:学习可解释的潜在表示
- StackGAN:分阶段生成高分辨率图像
- ProGAN:渐进式生成高分辨率图像
条件GAN通过引入条件信息,使得生成过程更加可控,能够生成特定类别的样本,在实际应用中具有广泛的用途。
二、CGAN 实现
2.1 导包
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
import os
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
# 判断是否存在可用的GPU
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 设置日志
time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) # 生成当前时间格式(例如:2024-03-15_14-30-00)
log_dir = os.path.join("./logs/cgan", time_str) # 设置日志路径,格式如:./logs/cgan/2024-03-15_14-30-00
os.makedirs(log_dir, exist_ok=True) # 自动创建目录
writer = SummaryWriter(log_dir=log_dir) # 初始化 SummaryWriter
os.makedirs("./img/cgan_mnist", exist_ok=True) # 存放生成样本目录
os.makedirs("./model", exist_ok=True) # 模型存放目录
2.2 数据加载和处理
# 加载 MNIST 数据集
def load_data(batch_size=64,img_shape=(1,32,32)):
transform = transforms.Compose([
transforms.Resize((img_shape[1],img_shape[2])),
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize(mean=[0.5], std=[0.5]) # 归一化到[-1,1]
])
# 下载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=2,shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=2,shuffle=False)
return train_loader, test_loader
2.3 构建生成器
class Generator(nn.Module):
"""生成器"""
def __init__(self, img_shape=(1,32,32),latent_dim=100,num_classes=10,label_embed_dim=10):
"""
Args:
img_shape (int, optional):
生成图片大小,默认CHW=1*32*32
latent_dim (int, optional):
潜在噪声向量的维度。默认100维,作为生成器的随机输入种子。
num_classes (int, optional):
类别数量。默认10(例如MNIST的0-9数字分类)。
决定标签嵌入矩阵的行数。
label_embed_dim (int, optional):
标签嵌入向量的维度。默认10维。
将离散标签映射为连续向量的维度,影响条件信息的表达能力。
"""
super(Generator, self).__init__()
# 定义嵌入层 [batch_szie]-> [batch_size,label_embed_dim]=[64,10]
self.label_embed = nn.Embedding(num_classes, label_embed_dim) # num_classes 个类别, label_embed_dim 维嵌入
# 定义网络块
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
return layers
# 定义模型架构
self.model = nn.Sequential(
*block(latent_dim + label_embed_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))), # [batch_size,1024]-> [batch_size,1*32*32]
nn.Tanh() # 输出归一化到[-1,1]
)
def forward(self, z, labels):
# 嵌入标签 [batch_size]-> [batch_size,label_embed_dim]=[64,10]
label_embed = self.label_embed(labels)
# 拼接嵌入标签和噪声 ->[batch_size,latent_dim + label_embed_dim]=[64,100+10]
gen_input = torch.cat([label_embed, z], dim=1)
# 生成图片-> [batch_size,C,H,W]=[64,1,32,32]
img = self.model(gen_input) # -> [batch_size,C*H*W]=[64,1*32*32]
img = img.view(img.shape[0], *img_shape) # [batch_size,C*H*W]-> [batch_size,C,H,W]=[64,1,32,32]
return img
2.4 构建判别器
class Discriminator(nn.Module):
"""判别器"""
def __init__(self, img_shape=(1,32,32),label_embed_dim=10):
super(Discriminator, self).__init__()
# 定义嵌入层 [batch_szie]-> [batch_size,label_embed_dim]=[64,10]
self.label_embed = nn.Embedding(num_classes, label_embed_dim) # num_classes 个类别, label_embed_dim 维嵌入
# 定义模型结构
self.model = nn.Sequential(
nn.Linear(label_embed_dim+ int(np.prod(img_shape)), 512), # [64,10+1*32*32]-> [64,512]
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(512, 512),
nn.Dropout(0.4),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Linear(512, 1),
)
def forward(self, img, labels):
# 嵌入标签 [batch_size]-> [batch_size,label_embed_dim]=[64,10]
label_embed = self.label_embed(labels)
# 输入图片展平[64,1,32,32]-> [64,1*32*32]
img=img.view(img.shape[0], -1)
# 拼接嵌入标签和输入图片 ->[batch_size,label_embed_dim + C*H*W]=[64,10+1*32*32]
dis_input = torch.cat([label_embed, img], dim=1)
# 进行判定
validity = self.model(dis_input)
return validity # -> [64,1]
2.5 训练和保存模型
1. 定义保存生成样本
def sample_image(G,n_row, batches_done,latent_dim=100,device=device):
"""Saves a grid of generated digits ranging from 0 to n_classes"""
# 随机噪声-> [n_row ** 2,latent_dim]=[100,100]
z=torch.normal(0,1,size=(n_row ** 2,latent_dim),device=device) #从正态分布中抽样
# 条件标签->[100]
labels = torch.arange(n_row, dtype=torch.long, device=device).repeat_interleave(n_row)
gen_imgs = G(z, labels)
save_image(gen_imgs.data, "./img/cgan_mnist/%d.png" % batches_done, nrow=n_row, normalize=True)
2. 训练和保存
# 设置超参数
batch_size = 64
epochs = 200
lr= 0.0002
latent_dim=100 # 生成器输入噪声向量的长度(维数)
sample_interval=400 #每400次迭代保存生成样本
img_shape = (1,32,32) # 图片大小
num_classes=10 # 分类数
label_embed_dim=10 # 嵌入维数
# 加载数据
train_loader,_= load_data(batch_size=batch_size,img_shape=img_shape)
# 实例化生成器G、判别器D
G=Generator().to(device)
D=Discriminator().to(device)
# 设置优化器
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr,betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr,betas=(0.5, 0.999))
# 损失函数
loss_fn=nn.BCEWithLogitsLoss()
# 开始训练
dis_costs,gen_costs = [],[] # 记录生成器和判别器每次迭代的开销(损失)
start_time = time.time() # 计时器
loader_len=len(train_loader) #训练集加载器的长度
for epoch in range(epochs):
# 进入训练模式
G.train()
D.train()
#记录生成器G和判别器D的总损失(1个 epoch 内)
gen_loss_sum,dis_loss_sum=0.0,0.0
loop = tqdm(train_loader, desc=f"第{epoch+1}轮")
for i, (real_imgs, real_labels) in enumerate(loop):
real_imgs=real_imgs.to(device) # [B,C,H,W]
real_labels=real_labels.to(device) # [B]
# 平滑真假标签,2维[B,1]
valid_labels = torch.empty(real_imgs.shape[0], 1, device=device).uniform_(0.9, 1.0).requires_grad_(False) # 替代1.0
fake_labels = torch.empty(real_imgs.shape[0], 1, device=device).uniform_(0.0, 0.1).requires_grad_(False) # 替代0.0
# -----------------
# 训练生成器
# -----------------
# 获取噪声样本[batch_size,latent_dim]及对应的条件标签 [batch_size]
z=torch.normal(0,1,size=(real_imgs.shape[0],latent_dim),device=device) #从正态分布中抽样
gen_labels = torch.randint(0, num_classes, (real_imgs.shape[0],), device=device, dtype=torch.long) # 0~9整数之间,随机抽 real_imgs.shape[0]次
# 计算生成器损失
gen_imgs=G(z,gen_labels)
gen_loss=loss_fn(D(gen_imgs,gen_labels),valid_labels)
# 更新生成器参数
optimizer_G.zero_grad() #梯度清零
gen_loss.backward() #反向传播,计算梯度
optimizer_G.step() #更新生成器
# -----------------
# 训练判别器
# -----------------
# 计算判别器损失
# Step-1:对真实图片损失
valid_loss=loss_fn(D(real_imgs,real_labels),valid_labels)
# Step-2:对生成图片损失
fake_loss=loss_fn(D(gen_imgs.detach(),gen_labels),fake_labels)
# Step-3:整体损失
dis_loss=(valid_loss+fake_loss)/2.0
# 更新判别器参数
optimizer_D.zero_grad() #梯度清零
dis_loss.backward() #反向传播,计算梯度
optimizer_D.step() #更新判断器
# 对生成器和判别器每次迭代的损失进行累加
gen_loss_sum+=gen_loss
dis_loss_sum+=dis_loss
gen_costs.append(gen_loss.item())
dis_costs.append(dis_loss.item())
# 每 sample_interval 次迭代保存生成样本
batches_done = epoch * loader_len + i
if batches_done % sample_interval == 0:
sample_image(G=G,n_row=10, batches_done=batches_done)
# 更新进度条
loop.set_postfix(mean_gen_loss=f"{gen_loss_sum/(loop.n + 1):.8f}",mean_dis_loss=f"{dis_loss_sum/(loop.n + 1):.8f}")
writer.add_scalars(
main_tag="Train Losses",
tag_scalar_dict={
"Generator": gen_loss,
"Discriminator": dis_loss,
},
global_step=batches_done # X轴坐标
)
writer.close()
print('总共训练用时: %.2f min' % ((time.time() - start_time)/60))
#仅保存模型的参数(权重和偏置),灵活性高,可以在不同的模型结构之间加载参数
torch.save(G.state_dict(), "./model/CGAN_G.pth")
torch.save(D.state_dict(), "./model/CGAN_D.pth")
2.6 绘制训练损失
# 创建画布
plt.figure(figsize=(10, 5))
ax1 = plt.subplot(1, 1, 1)
# 绘制曲线
ax1.plot(range(len(gen_costs)), gen_costs, label='Generator loss', linewidth=2)
ax1.plot(range(len(dis_costs)), dis_costs, label='Discriminator loss', linewidth=2)
ax1.set_xlabel('Iterations', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('CGAN Training Loss', fontsize=14)
ax1.legend(fontsize=10)
ax1.grid(True, linestyle='--', alpha=0.6)
ax2 = ax1.twiny() # 创建共享Y轴的新X轴
newlabel = list(range(epochs+1)) # Epoch标签 [0,1,2,...]
iter_per_epoch = len(train_loader) # 每个epoch的iteration次数
newpos = [e*iter_per_epoch for e in newlabel] # 计算Epoch对应的iteration位置
ax2.set_xticks(newpos[::10])
ax2.set_xticklabels(newlabel[::10])
ax2.xaxis.set_ticks_position('bottom')
ax2.xaxis.set_label_position('bottom')
ax2.spines['bottom'].set_position(('outward', 45)) # 坐标轴下移45点
ax2.set_xlabel('Epochs') # 设置轴标签
ax2.set_xlim(ax1.get_xlim()) # 与主X轴范围同步
plt.tight_layout()
plt.savefig('cgan_loss.png', dpi=300)
plt.show()
2.7 图片转GIF
from PIL import Image
def create_gif(img_dir="./img/cgan_mnist", output_file="./img/cgan_mnist/cgan_figure.gif", duration=100):
images = []
img_paths = [f for f in os.listdir(img_dir) if f.endswith(".png")]
img_paths_sorted = sorted(
img_paths,
key=lambda x: (
int(x.split('.')[0]), # (如 400.png 的 400)
)
)
for img_file in img_paths_sorted:
img = Image.open(os.path.join(img_dir, img_file))
images.append(img)
images[0].save(output_file, save_all=True, append_images=images[1:],
duration=duration, loop=0)
print(f"GIF已保存至 {output_file}")
create_gif()
2.8 模型加载和生成
#载入训练好的模型
G = Generator() # 定义模型结构
G.load_state_dict(torch.load("./model/CGAN_G.pth",weights_only=True,map_location=device)) # 加载保存的参数
G.to(device) # 将模型移动到设备(GPU 或 CPU)
G.eval() # 将模型设置为评估模式
# 获取噪声样本[10,100]及对应的条件标签 [10]
z=torch.normal(0,1,size=(10,100),device=device) #从正态分布中抽样
gen_labels = torch.arange(10, dtype=torch.long, device=device) #0~9整数
#生成假样本
gen_imgs=G(z,gen_labels).view(-1,32,32) # 4维->3维
gen_imgs=gen_imgs.detach().cpu().numpy()
# #绘制
plt.figure(figsize=(3, 2))
for i in range(10):
plt.subplot(2, 5, i + 1)
plt.xticks([], [])
plt.yticks([], [])
plt.imshow(gen_imgs[i], cmap='gray')
plt.title(f"Figure {i}", fontsize=5)
plt.tight_layout()
plt.show()

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐



所有评论(0)