【pytorch】使用生成对抗网络GAN实现MINIST手写数字图像生成
本文回顾了作者从零开始构建、调试和优化一个基础生成对抗网络(GAN)的全过程,目标是利用经典的 MNIST 手写数字数据集来训练一个能生成逼真数字图像的模型。
前言
您是否曾对AI能够创作出以假乱真的图像而感到惊叹?无论是生成风景、人脸还是艺术画作,这些令人印象深刻的应用背后,常常有一种迷人而强大的深度学习模型在工作,它就是生成对抗网络(Generative Adversarial Network, GAN)。
GAN的理念天才而直观。想象一下,我们同时训练两个相互竞争的神经网络:
-
一位是伪造者(生成器 Generator),它的任务是从一堆随机噪声开始,学习画出逼真的“赝品”,比如我们这次要挑战的手写数字。
-
另一位是鉴定师(判别器 Discriminator),它的任务是“火眼金睛”,尽可能准确地分辨出哪些是来自真实数据集的“真品”,哪些是“伪造者”画的赝品。
这两个网络在“对抗”中共同进化:伪造者为了骗过鉴定师而画得越来越好,鉴定师为了不被欺骗而看得越来越准。最终,当这个游戏达到一种精妙的平衡时,我们就得到了一个技艺高超的“伪造大师”——一个能够生成全新、逼真图像的生成器。
在这篇博客中,我将带您回顾我从零开始构建、调试并优化一个基础GAN网络的全过程。我们将使用经典的手写数字数据集MNIST,亲眼见证模型如何从最初的随机噪声,逐渐学会生成清晰、多样的数字图片。我们还会一起探讨训练过程中遇到的经典问题,如“模式崩溃”,并分享如何通过调整损失函数和超参数来解决它们,最终达到理想的生成效果。
项目实战
1.1导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import random
1.2参数设置
# 设置一个固定的种子值
seed = 42
# 为 Python 内置的 random 模块设定种子
random.seed(seed)
# 为 NumPy 设定种子
np.random.seed(seed)
#为GPU,CPU设置种子
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
#定义超参数
noise_dim = 100
noise_dim 的全称是 Noise Dimension,即噪声维度。
简单来说,它定义了提供给生成器(Generator)的初始“灵感”的复杂度。
noise_dim = 100
noise_dim=100 意味着模型设计者给生成器分配了 100 种核心的、潜在的特征来组合成最终的图像。
1.3使用gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1.4数据处理
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
1.5加载数据集
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_dataset, batch_size=64, shuffle=True, num_workers=4)
在这里我们给num_workers设置为4,意味着框架将创建 4 个独立的子进程 (subprocesses) 来执行数据准备任务。
1.6定义模型
1.6.1创建生成器
# 创建生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
# 定义每个block的结构
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*block(noise_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, 28 * 28),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(-1, 1, 28, 28)
return img
# 定义每个block的结构
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
功能: 这个函数用于创建一个标准的网络层“块”。这个“块”由三个部分组成:一个线性层、一个可选的批归一化层和一个激活函数层。
参数:
-
in_feat: 输入特征的数量(前一层的神经元数量)。 -
out_feat: 输出特征的数量(当前层的神经元数量)。 -
normalize=True: 一个布尔值开关,决定是否要添加批归一化(Batch Normalization)层。默认是添加。
内部逻辑:
-
layers = [nn.Linear(in_feat, out_feat)]: 创建一个列表,并首先加入一个全连接层(nn.Linear)。它将输入数据从in_feat维度线性变换到out_feat维度。 -
if normalize: layers.append(nn.BatchNorm1d(out_feat, 0.8)): 如果normalize为True,则添加一个一维批归一化层(nn.BatchNorm1d)。-
作用: 批归一化可以加速模型训练,提高稳定性,并有轻微的正则化效果。它通过重新缩放和中心化上一层的输出来实现。
-
0.8是momentum参数,用于计算运行中的均值和方差。
-
-
layers.append(nn.LeakyReLU(0.2, inplace=True)): 添加Leaky ReLU激活函数。-
作用: Leaky ReLU 允许在输入为负时有一个小的、非零的梯度(这里是0.2),这有助于解决标准ReLU函数中可能出现的“神经元死亡”问题。
-
inplace=True: 这是一个内存优化选项,它会直接修改输入数据,而不是创建一个新的输出对象,从而节省内存。
-
-
return layers: 函数返回一个包含这些层定义的列表。
self.model = nn.Sequential(
*block(noise_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, 28 * 28),
nn.Tanh()
)
这里是生成器网络的主体结构,使用了 nn.Sequential 容器。
-
nn.Sequential: 这是一个序列容器,它会将你传入的模块按照顺序连接起来。数据会依次通过这些层。 -
*block(...): 这里的星号*是Python的解包操作符。因为block函数返回的是一个层的列表(例如[nn.Linear, nn.BatchNorm1d, nn.LeakyReLU]),星号*会将这个列表“打散”成独立的元素,然后作为参数传递给nn.Sequential。-
例如,
*block(128, 256)等价于nn.Linear(128, 256), nn.BatchNorm1d(256, 0.8), nn.LeakyReLU(0.2, inplace=True)。
-
-
网络流程:
-
*block(noise_dim, 128, normalize=False):-
输入层。它接收一个维度为
noise_dim的噪声向量。 -
将其扩展到
128维。 -
normalize=False:通常输入层之后不加批归一化。
-
-
*block(128, 256),*block(256, 512),*block(512, 1024):-
这是三个隐藏层。它们逐步将特征维度从
128放大到1024。这是一个典型的生成器结构,通过逐层放大来学习更复杂的特征表示。
-
-
nn.Linear(1024, 28 * 28):-
输出层。将
1024维的特征向量映射到一个28 * 28 = 784维的向量。这个向量的每一个元素都对应最终生成图像的一个像素点。
-
-
nn.Tanh():-
输出激活函数。Tanh 函数将输出值压缩到
[-1, 1]的范围内。 -
这在GAN中非常常用,因为通常我们会将输入的真实图像像素值也归一化到
[-1, 1]这个范围,以匹配生成器的输出。
-
-
def forward(self, z):
img = self.model(z)
img = img.view(-1, 1, 28, 28)
return img
-
def forward(self, z): 定义前向传播函数,输入参数z就是我们前面提到的随机噪声向量。它的形状通常是[batch_size, noise_dim]。 -
img = self.model(z): 将噪声向量z传入我们之前用nn.Sequential定义好的模型中。-
经过
self.model的计算后,img的形状会是[batch_size, 784]。它是一个扁平化的向量。
-
-
img = img.view(-1, 1, 28, 28): 这是非常关键的一步,用于重塑(reshape)张量的形状。-
.view()函数将扁平的向量img转换成图像的格式。 -
-1: 表示这个维度的大小由PyTorch自动推断。在这里,它就是batch_size。 -
1: 表示图像的通道数。因为目标是生成灰度图,所以通道数是1。 -
28,28: 分别表示图像的高度(Height)和宽度(Width)。 -
最终,输出的
img张量形状为[batch_size, 1, 28, 28],这是一个标准的PyTorch图像格式(BCHW:批量大小,通道,高度,宽度)。
-
-
return img: 返回最终生成的、具有图像形状的张量。
总结
Generator(生成器)实现了一个完整的生图功能,通过接受低维的随机噪声向量,经过block的学习,然后生成输出一个一维向量,并重塑为标准图像格式,以便后续处理。
1.6.2创建判别器
判别器是生成对抗网络 (GAN) 中的另一半核心组件。如果说生成器 (Generator) 是一个“伪造者”或“艺术家” 👨🎨,那么判别器 (Discriminator) 就是一个“鉴定师”或“警察” 👮♀️。
它的核心任务是:接收一张图片,然后判断这张图片是“真的”(来自真实数据集)还是“假的”(由生成器伪造)。
# 创建判别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
x = torch.flatten(x, 1)
x = self.model(x)
return x
判别器的结构,刚好与生成器相反,它是将一张高维图片压缩成一个低维的判断结果。
1.7初始化配置
adversarial_loss = nn.BCELoss().to(device)
generator = Generator().to(device)
discriminator = Discriminator().to(device)
lr_G = 0.001
lr_D = 0.0005
optimizer_G = optim.Adam(generator.parameters(), lr=lr_G, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_D, betas=(0.5, 0.999))
adversarial_loss = nn.BCELoss().to(device)
nn.BCELoss(): 这行代码定义了“对抗损失”的计算方式。BCE 的全称是 Binary Cross-Entropy Loss (二元交叉熵损失)。
-
作用: 这是专门用于二分类问题的损失函数。在 GAN 中,判别器(Discriminator)的任务恰好就是一个二分类问题:它需要判断一张图片是真的(类别1)还是假的(类别0)。
-
工作原理:
BCELoss会比较判别器的输出(一个0到1之间的概率)和真实的标签(1代表真,0代表假)。如果判别器猜错了,比如把一张真图片判断为假的(概率接近0),BCELoss就会计算出一个较大的损失值(惩罚)。反之,如果猜对了,损失值就很小。这个损失值就是模型学习和优化的依据。
generator = Generator().to(device)
discriminator = Discriminator().to(device)
实例化模型,并加载到GPU上。
lr_G = 0.001
lr_D = 0.0005
定义生成器和判别器的学习率,将判别器的学习率设定的更低,是为了让生成器能更好的跟上,保持训练时的一个平衡。
optimizer_G = optim.Adam(generator.parameters(), lr=lr_G, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_D, betas=(0.5, 0.999))
创建2个优化器。
1.8开始训练
# 设定要展示的图片数量和固定的噪声
num_samples_to_show = 25
fixed_z = torch.randn(num_samples_to_show, noise_dim, device=device)
n_epochs = 150
for epoch in range(n_epochs):
generator.train()
discriminator.train()
generator_loss_acc, discriminator_loss_acc = 0, 0
for batch_idx, (imgs, _) in enumerate(train_loader):
real_labels = torch.ones(imgs.shape[0], 1, device=device)
fake_labels = torch.zeros(imgs.shape[0], 1, device=device)
real_imgs = imgs.to(device)
# --- 训练生成器 ---
optimizer_G.zero_grad()
z = torch.randn(imgs.shape[0], noise_dim, device=device)
gen_imgs = generator(z)
g_loss = adversarial_loss(discriminator(gen_imgs), real_labels)
g_loss.backward()
optimizer_G.step()
# --- 训练判别器 ---
optimizer_D.zero_grad()
real_loss = adversarial_loss(discriminator(real_imgs), real_labels)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake_labels)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
# 累加损失
generator_loss_acc += g_loss.item()
discriminator_loss_acc += d_loss.item()
# 打印每个epoch的平均损失
print(f"Epoch: {epoch+1}/{n_epochs}, Generator Loss: {generator_loss_acc / len(train_loader):.4f}, Discriminator Loss: {discriminator_loss_acc / len(train_loader):.4f}")
# --- 每个epoch结束后,生成并显示图片 ---
generator.eval()
with torch.no_grad():
gen_imgs = generator(fixed_z).detach().cpu()
gen_imgs = gen_imgs * 0.5 + 0.5
fig, axs = plt.subplots(5, 5, figsize=(8, 8))
for i in range(5):
for j in range(5):
axs[i, j].imshow(gen_imgs[i*5+j].squeeze(), cmap='gray')
axs[i, j].axis('off')
plt.suptitle(f"Epoch {epoch+1}", y=0.92)
plt.show()
print("--------- Training Finished ---------")
fixed_z = torch.randn(num_samples_to_show, noise_dim, device=device)
创建固定噪声向量。
后面的流程就是分别训练生成器和判别器了,我就不在过多叙述了,然后每个epoch结束后,都会打印出生成的图片。

总结
本人水平有限,文中若有疏漏或错误之处,恳请各位不吝指正。
欢迎与我交流探讨,共同进步。
所附代码仅为示例实现,未必是最优解,大家可根据实际需求自由调整优化。
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)