20.生成对抗网络


本课程来自深度之眼deepshare.net,部分截图来自课程视频。
Chloe H. 提供:GAN训练的tip,
https://chloes-dl.com/2019/11/19/tricks-and-tips-for-training-a-gan/

生成对抗网络(GAN)是什么?

GAN(Generative Adversarial Nets):生成对抗网络——一种可以生成特定分布数据的模型
文献:Generative Adversarial Nets. Ian Goodfellow. 2014
在这里插入图片描述
在这里插入图片描述
堃哥说:
Adversarial training is the coolest thing since sliced bread.I’ ve listed a bunch of relevant papers in a previous answer. Expect more impressive results with this technique in the coming years.
下面是用GAN生成的64张人脸。有些比较畸形。。。
在这里插入图片描述

inference

1.输入:
用高斯分布随机的采样一些噪声
2.构建模型,加载参数:
这里注意,模型inference时只用到了Generator(生成器),不需要Discriminator(判别器)
3.inference,把输入放到Generator中就可以生成虚假数据。
fake_data=net_g(fixed_noise). detach(). cpu()

GAN网络结构

以下三个图片(每个图片都是讲的GAN结构)分别来自:
《Recent Progress on Generative Adversarial Networks(GANs):A Survey》
《How Generative Adversarial Networks and Its Variants Work:An Overview of GAN》
《Generative Adversarial Networks_A Survey and Taxonomy》
G代表生成器,D代表判别器,z是输入向量,输入向量通过生成器后,得到一个生成的结果,如果是人脸图片生成,这个的G(z)就是一个图片tensor,然后结合训练数据x,通过判别器给出图片是真还是假(D是二分类网络。)
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

如何训练GAN?

训练目的
1.对于D:对真样本输出高概率
2.对于G:输出使D会给出高概率的数据
在这里插入图片描述
GAN的训练模式与监督学习训练模式不一样的地方:需要注意的是,监督学习中损失函数的目标是让模型的输出值尽量的逼近真实值;在GAN中输出值不是逼近真实值,而是使得输出值的分布接近真实值的分布。
在这里插入图片描述
下面看具体步骤,二次元警告。。。。(李宏毅的笔记里面也有相应内容)
step1:训练D
输入:真实数据加G生成的假数据
输出:二分类概率
在这里插入图片描述
上图中是更新一次D的过程
step2:训练G
输入:随机噪声z
输出:分类概率——D(G(z))
在这里插入图片描述
上图中输出如果是0.13,那么差异为1-0.13,我们的目标是D输出的目标概率是越高越好,最好就是1,这里只有0.13,说明还不够好,需要继续训练G。
然后回到step1继续循环,知道满足收敛条件。
下面对GAN论文中对算法的文字进行一些解释
在这里插入图片描述
1.整个算法是一个大的for循环,可以根据图中的最长的横线分为两个部分,上面部分是训练判别器的,下面部分是训练生成器的。
2.先看训练判别器部分,这个部分是有一个for循环包围着的(1号箭头),这个是早期GAN的设置,意思是先要通过几次迭代训练几次判别器,后来经过实践证明,这里实际上是不需要的,只用训练一次就ok了,所以这里的循环次数k我们可以设置为1。
3.在训练判别器时,先分别从噪声和真实数据中进行采样,然后计算损失函数,注意在更新损失函数,用的是ascending梯度,原因分析:损失函数有两项,第一项是真实数据,我们希望这个的概率是越大越好(2号箭头),第二项是虚假数据,这个概率我们希望是越小越好,但是又有一个1-这一项,所以整个第二项也是越大越好(3号箭头),整体更新是变大的趋势,所以用的随机梯度上升法。
4.训练生成器部分,先从噪声中采样(这里的采样数据可以和上面部分的相同,也可以不同,感觉可以这样是因为我们在乎的是数据的分布,而不是具体的数据)
5.同理,生成器希望这个损失函数的值通过判别器判别后是真实数据(生成器要骗过判别器),所以 D ( G ( z ( i ) ) ) D(G(z^{(i)})) D(G(z(i)))这项是越大越好(4号箭头),则整体是越小越好(5号箭头)。因此在生成器部分用的是随机梯度下降法
6.可以看出,由于是对抗,在设计损失函数的时候,一个是梯度上升,一个是梯度下降;另外两个损失函数有一项是一样的,看图中绿线部分。

训练DCGAN实现人脸生成

《Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks》

Generator:卷积结构的模型

在这里插入图片描述
输入是100维的随机噪声,然后通过transpose的卷积生成一个64643的rgb图片
注意:输入在pytorch中用tensor表示为:(1,100,1,1)
第一个1 是batch,后面两个1是高和宽。在这里插入图片描述

Discriminator:卷积结构的模型

老师很懒,直接把上面的结构旋转180度,输入是64643的rgb图像,不过输出是二分类。
在这里插入图片描述
DCGAN实现人脸生成
数据:CelebA人脸数据。
数据项目:http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
不是用的原项目的人脸,而是用矫正过的。
22万人脸矫正图:
https://pan.baidu.com/s/1JDrl82vTjgFsmKQ0SPNtzA 密码:41g7 失效
矫正前:
在这里插入图片描述
人脸所在位置以及比例都不确定
矫正后,是通过五个人脸关键点(中心化)以及人脸所占比例进行了矫正:
在这里插入图片描述
构建transform的时候,需要把数据尺度变换到-1~ 1区间,因为随机采用的生成器的值也是这个区间,所以这里不追求0均值的分布,而是追求区间一致。

生成器的超参数初始化代码:

class Generator(nn.Module):
    def __init__(self, nz=100, ngf=128, nc=3):#输入的维度是100,特征图数量是128,输出是3d张量
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),#ngf * 8=1024,对应到结构图中的一个卷积模块
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

二分类的loss

# step3: loss 
criterion=nn. BCELoss()

判别器和生成器的训练迭代过程的代码

############################
            # (1) Update D network
            ###########################

            net_d.zero_grad()

            # create training data
            real_img = data.to(device)
            b_size = real_img.size(0)
            real_label = torch.full((b_size,), real_idx, device=device)#real_idx是真实图片的lable

            noise = torch.randn(b_size, nz, 1, 1, device=device)#输入是4d张量,第一个维度是batchsize,nz是100维
            fake_img = net_g(noise)
            fake_label = torch.full((b_size,), fake_idx, device=device)#fake_idx是假图片lable

            # train D with real img
            out_d_real = net_d(real_img)
            loss_d_real = criterion(out_d_real.view(-1), real_label)

            # train D with fake img
            out_d_fake = net_d(fake_img.detach())
            loss_d_fake = criterion(out_d_fake.view(-1), fake_label)

            # backward
            loss_d_real.backward()
            loss_d_fake.backward()
            loss_d = loss_d_real + loss_d_fake

            # Update D
            optimizerD.step()

            # record probability
            d_x = out_d_real.mean().item()      # D(x)
            d_g_z1 = out_d_fake.mean().item()   # D(G(z1))
            
            #以上完成一次判别器的更新

            ############################
            # (2) Update G network
            ###########################
            net_g.zero_grad()

            label_for_train_g = real_label  # 1
            out_d_fake_2 = net_d(fake_img)

            loss_g = criterion(out_d_fake_2.view(-1), label_for_train_g)
            loss_g.backward()#只更新生成器,不改变判别器
            optimizerG.step()#

            # record probability
            d_g_z2 = out_d_fake_2.mean().item()  # D(G(z2))

            # Output training stats
            if i % 10 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, i, len(train_loader),
                         loss_d.item(), loss_g.item(), d_x, d_g_z1, d_g_z2))

            # Save Losses for plotting later
            G_losses.append(loss_g.item())
            D_losses.append(loss_d.item())

训练过程中的注意事项:
1.特征图数量ngf是原始模型128,如果改为64,效果会变差,但是训练速度快一些
2.标签值的平滑处理,这里用的是1和0,可以平滑为:0.9和0.1
GAN的应用
https://medium.com/@jonathan_hui/gan-some-cool-applications-of-gans-4c9ecca35900(失效)
GAN的应用:《CycleGAN》
在这里插入图片描述
GAN的应用:《PixelDTGAN》
在这里插入图片描述
GAN的应用:《SRGAN》
在这里插入图片描述
GAN的应用:
Progressive GAN
在这里插入图片描述
GAN的应用:
《StackGAN》根据文本生成图片
在这里插入图片描述
GAN的应用:
《Context Encoders》
在这里插入图片描述
GAN的应用:
《Pix2Pix》
在这里插入图片描述
GAN的应用:
《ICGAN》
在这里插入图片描述GAN推荐github:https://github.com/nightrome/really-awesome-gan

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

oldmao_2000

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值