对抗生成网络_GAN(生成对抗网络)学习笔记

本文深入探讨了生成对抗网络(GAN)的基本原理,包括最小二乘GAN(LSGAN)和条件GAN(CGAN)。通过计图工具,介绍了如何进行LSGAN和CGAN的训练及图像生成,并详细阐述了计图的分布式接口和多卡训练过程。文中还分享了DIY的CGAN实验,用于生成数字,并展示了实验结果。
摘要由CSDN通过智能技术生成

ed63b5243c69cb6221a31da547246e60.png

1bae4f0070034ba5d57362ca7f0957b7.png

07d9e0393d682e40f5e165d43afabecd.png

2379dd566a97e3553e3e199c7fdce1f1.png

20e9d8f693d678a22b08d5840ff3a2b3.png

c435c7752ee112eed309ac54de9141d7.png

a04e8b001498628332230ca0fa275c05.png

b458c34906a81af94be698912190952a.png

dfbfcec7d4d9c7ae0e75ae086de4875c.png

509f867c72e98db97cec2d4913c97873.png

13259cce0897c6c305b8610803acd896.png

6b2a7d2c00e0aaca6310d7278e62408a.png

7277b024202ca60e0a232e8d7fbbd386.png

ea535b5258a58c778216419593984b0f.png

cad57d0f51080861fdb387c8e77e600f.png

e3f6afb7c9639b57966cbe95dc8194ac.png

4ed41bcb525c31a3f7fbdb48dd3fd8ec.png

7ada5e8a68c50a3a970c5d67f550bbac.png

d41cfd45c5e8fc61f42185985507b6ff.png

f5414cf59c75365e8b45be03f1dd31a8.png

230a3907256fcafe59a55f773a6c101f.png

0178865fbf8ced3e448ccab6f0355f18.png

5b34d515038186392efbea9042aab70d.png

043737048421d324301a2425176047a3.png

3c575d84d0a1a4efa045b816fe1ac58e.png

大纲

生成对抗网络(GAN)

Generative Adversarial Networks

第一作者Ian Goodfellow

目前引用量1w8+

GAN主要用于样本生成

GAN由生成器判别器组成:

生成器的功能是输入一个样本将其输出成一个逼真的样子

判别器来判断输入的样本是真的还是伪造的。

原理

生成对抗网络(GAN)

最小二乘GAN(LSGAN)

Conditional GAN(CGAN)

实践

计图的安装与使用

LSGAN 训练与生成

CGAN 训练与生成

模型迁移:计图辅助转换工具

多机多卡:计图分布式接口

计图依赖OpenMPI,使用如下命令安装OpenMPI:

sudo apt install openmpi-bin openmpi-common libopenmpi-dev

计图会自动检测环境变量中是否包含mpicc,如果计图成功的检测到了mpicc,那么会输出如下信息:

[i 0502 14:09:55.758481 24 __init__.py:203] Found mpicc(1.10.2) at /usr/bin/mpicc

如果计图没有在环境变量中找到mpi,手动指定mpicc的路径告诉计图,添加环境变量:

export mpicc_path=/you/mpicc/path

计图分布式原理

单卡训练代码

python3.7 -m jittor.test.test_resnet

分布式多卡训练代码

mpirun -np 4 python3.7 -m jittor.test.test_resnet

指定特定显卡的多卡训练代码

CUDA_VISIBLE_DEVICES="2,3" mpirun -np 2 python3.7 –m jittor.test.test_resnet

我这次DIY实验,实现了CGAN生成数字,训练器和生成器的代码段如下:

# ----------#  Training# ----------for epoch in range(opt.n_epochs):    for i, (imgs, labels) in enumerate(dataloader):        batch_size = imgs.shape[0]        # Adversarial ground truths        valid = jt.ones([batch_size, 1]).float32().stop_grad()        fake = jt.zeros([batch_size, 1]).float32().stop_grad()        # Configure input        real_imgs = jt.array(imgs)        labels = jt.array(labels)        # -----------------        #  Train Generator        # -----------------        # Sample noise and labels as generator input             z=jt.array(np.random.normol(0,1,(batch_size, opt.latent_dim))).float32()    #sample noise-随机一维噪声z             gen_labels=jt.array(np.random.randint(0,opt.n_classes,batch_size)).float32()  #labels类别标签        # Generate a batch of images              gen_imgs=generator(z,gen_labels)          # Loss measures generator's ability to fool the discriminator              validity=dsicriminator(gen_imgs,gen_labels)              g_loss=adversarial_loss(validity,valid)              g_loss.sync()              optimizer_G.step(g_loss)        # ---------------------        #  Train Discriminator        # ---------------------        #       - 尽可能识别real_imgs为valid        #       - 尽可能识别gen_imgs为fake        # Loss for real images            validity_real=discriminator(real_imgs,labels)            d_real_loss=adversarial_loss(validity_real,valid)                    # Loss for fake images             validity_fake=discriminator(gen_imgs.stop_grad(),gen_labels)    #             d_fake_loss=adversarial_loss(validity_fake,fake)                 #        # Total discriminator loss            d_loss = (d_real_loss + d_fake_loss) / 2            d_loss.sync()            optimizer_D.step(d_loss)                if i  % 50 == 0:            print(                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.data, g_loss.data)            )        batches_done = epoch * len(dataloader) + i        if batches_done % opt.sample_interval == 0:            sample_image(n_row=10, batches_done=batches_done)    if epoch % 10 == 0:        generator.save("saved_models/generator_last.pkl")        discriminator.save("saved_models/discriminator_last.pkl")!

实验结果图:

3445f6fd5cb2ccf54cf0b6fd15c863ab.png

再见!

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值