对抗生成网络_GAN(生成对抗网络)学习笔记
大纲生成对抗网络(GAN)Generative Adversarial Networks第一作者Ian Goodfellow目前引用量1w8+GAN主要用于样本生成GAN由生成器和判别器组成:生成器的功能是输入一个样本将其输出成一个逼真的样子判别器来判断输入的样本是真的还是伪造的。原理生成对抗网络(GAN)最小二乘GAN(LSGAN)Conditional GAN(CGAN)实践计图的安...

























大纲
生成对抗网络(GAN)
Generative Adversarial Networks
第一作者Ian Goodfellow
目前引用量1w8+
GAN主要用于样本生成
GAN由生成器和判别器组成:
生成器的功能是输入一个样本将其输出成一个逼真的样子
判别器来判断输入的样本是真的还是伪造的。
原理
生成对抗网络(GAN)
最小二乘GAN(LSGAN)
Conditional GAN(CGAN)
实践
计图的安装与使用
LSGAN 训练与生成
CGAN 训练与生成
模型迁移:计图辅助转换工具
多机多卡:计图分布式接口
计图依赖OpenMPI,使用如下命令安装OpenMPI:
sudo apt install openmpi-bin openmpi-common libopenmpi-dev
计图会自动检测环境变量中是否包含mpicc,如果计图成功的检测到了mpicc,那么会输出如下信息:
[i 0502 14:09:55.758481 24 __init__.py:203] Found mpicc(1.10.2) at /usr/bin/mpicc
如果计图没有在环境变量中找到mpi,手动指定mpicc的路径告诉计图,添加环境变量:
export mpicc_path=/you/mpicc/path
计图分布式原理
单卡训练代码
python3.7 -m jittor.test.test_resnet
分布式多卡训练代码
mpirun -np 4 python3.7 -m jittor.test.test_resnet
指定特定显卡的多卡训练代码
CUDA_VISIBLE_DEVICES="2,3" mpirun -np 2 python3.7 –m jittor.test.test_resnet
我这次DIY实验,实现了CGAN生成数字,训练器和生成器的代码段如下:
# ----------# Training# ----------for epoch in range(opt.n_epochs): for i, (imgs, labels) in enumerate(dataloader): batch_size = imgs.shape[0] # Adversarial ground truths valid = jt.ones([batch_size, 1]).float32().stop_grad() fake = jt.zeros([batch_size, 1]).float32().stop_grad() # Configure input real_imgs = jt.array(imgs) labels = jt.array(labels) # ----------------- # Train Generator # ----------------- # Sample noise and labels as generator input z=jt.array(np.random.normol(0,1,(batch_size, opt.latent_dim))).float32() #sample noise-随机一维噪声z gen_labels=jt.array(np.random.randint(0,opt.n_classes,batch_size)).float32() #labels类别标签 # Generate a batch of images gen_imgs=generator(z,gen_labels) # Loss measures generator's ability to fool the discriminator validity=dsicriminator(gen_imgs,gen_labels) g_loss=adversarial_loss(validity,valid) g_loss.sync() optimizer_G.step(g_loss) # --------------------- # Train Discriminator # --------------------- # - 尽可能识别real_imgs为valid # - 尽可能识别gen_imgs为fake # Loss for real images validity_real=discriminator(real_imgs,labels) d_real_loss=adversarial_loss(validity_real,valid) # Loss for fake images validity_fake=discriminator(gen_imgs.stop_grad(),gen_labels) # d_fake_loss=adversarial_loss(validity_fake,fake) # # Total discriminator loss d_loss = (d_real_loss + d_fake_loss) / 2 d_loss.sync() optimizer_D.step(d_loss) if i % 50 == 0: print( "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data) ) batches_done = epoch * len(dataloader) + i if batches_done % opt.sample_interval == 0: sample_image(n_row=10, batches_done=batches_done) if epoch % 10 == 0: generator.save("saved_models/generator_last.pkl") discriminator.save("saved_models/discriminator_last.pkl")!
实验结果图:

再见!
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)