AI入门----神经网络实战----GAN

前言
之前我们学了AutoEncoder生成式神经网络。从效果看,生成的图片有点模糊。这里我们将介绍另一种生成式神经网络:GAN。这可能是深度学习中最让人兴奋的一个方向了。我们在这里做一个形象的比方:AutoEncoder就像是一个圣人,他告诉大家该怎么做事,怎么说话,甚至怎么走路。所以大家都向圣人学习。最后人人都学的类似于圣人。这种方式的优点是学习的速度相对比较快。但是随着时代的发展,我们可以看到原先的圣人并不是什么都是对的,圣人也有错误和缺陷。所以,这种方式的缺点是很难得到最好的结果 (因为没有人知道最好的结果是什么)。
GAN网络像是一个游戏。游戏中有一个想努力改进自己的人,他叫做G,但是他不知道怎么改进自己。G请了一个批评家D来指出他的缺点。开始的时候,G和D的水平都很差,G有100个缺点,D最多只能找出1~2个,后来G和D一起成长。G做的越来越好,D的鉴赏能力也越来越高。经过足够长的时间之后,D再也找不出G的任何缺点了,就这样G如愿以偿的变成了圣人。这种方式就是GAN,GAN的优点是不需要在开始的时候设定圣人,所以这种方式是比较接地气的。

GAN显然是有缺点的:
1、如果G和D不匹配,则G和D都无法学习/成长。例如G本来就很好,D水平很差,则D永远找不到G的缺点。而如果G水平很差,D很厉害,那么无论G如何努力,它都无法达到D的要求,这种情况下G就会无所适从。另外,在训练过程中,G比D进步快太多 (或者相反),就会导致无法继续学习。也就是GAN要求G和D相互匹配。这一点在实际编程中是比较困难的。
2、如果开始的时候G和D都很差,那么它们的进步会很慢。也就是说GAN的训练是比较缓慢的。相对而言,VAE有一个圣人可以作为榜样,所以VAE的训练很快。
3、当G达到圣人的标准,D无法判断G是不是真的圣人。所以D只随便说G是或不是圣人。也就是说D的准确率只有50%。这时的准确率并不比刚开始的时候高。相同道理,G最后的准确率也不比开始的时候高。所以对于GAN来说,准确率以及损失值都是不能正确估量模型的真实情况。
什么是GAN?
有了上面感性的认识之后,我们重新介绍GAN:
GANs的中文翻译是:生成对抗网络。这个网络分为Generator和Discrimator这两部分。简单来说,就是让Generator生成假的数据,Discrimator判别真伪。这两个网络通过不断的相互竞争,最后让Generator生成的数据能够以假乱真。下图可以描绘GAN的过程:

在这里插入图片描述

下面我们来讲一下对抗过程。首先要讲一下生成网络 (Generative Network):

首先给出一个简单的高维的正态分布的噪声向量Z,然后通过生成网络 (这里简单的采用全连接) 生成一张与真实图片一样尺寸的图片,这就是我们所说的假的图片。接下去就是把这张假的图片,还有数据集中真的图片按照一定的比例混合后输送到判别器中。那么我们怎么训练这个生成器呢?就是通过判别器来得到结果,希望提高判别器判别这个结果为真的概率,在这一步我们不更新判别器的参数,只更新生成器的参数。

下面讲一下判别网络 (Discriminator Network):
判别网络中包含一个判别器,它是用来判断一张图片是真的还是假的,所以这是一个二分类问题,如果输入真的图片希望判别器输出的结果是1,输入假的图片希望判别器输出的结果是0。

Generative Network的实现:
生成器的代码非常简单:就是输入一个随机噪音数据z_prior,然后经过3层全连接层,得到28 * 28 = 784维的一张图片。这张图片如果经过变形,就可以变成(28, 28)的图片,尺寸与MNIST数据集中的图片有一样的尺寸。

在这里插入图片描述

Discriminator Network的实现:
判别器的代码也非常简单,就是把一张28 * 28 = 784的图片,经过3层全连接,最后用sigmoid()处理,得到0~1之间的一个数。如果最终的数据在输出值为[0, 0.5]表示为判别器认为这是假图片,(0.5, 1]表示判别器认为这是真的图片。

在这里插入图片描述

GAN的训练:
有了生成器和判别器之后,我们就可以开始考虑训练的过程了。摆在我们面前的问题是先训练哪个网络?对于一个随机噪音数据Z,生成器总是可以生成图片。这时我们就同时有了真假数据 (从MNIST中获得真实数据),这样就可以训练判别器了。而生成器需要通过判别器的结果的反馈来训练,所以开始的时候无法训练生成器。通过这个分析,我们就可以确认:先训练判别器。

为了训练判别器,我们必须先生成生成器的实例G和判别器的实例D。然后定义D的损失函数和优化器。

在这里插入图片描述
接着从MNIST数据集中获取到一个batch的真实图片img。把图片转化成一维数据后转化成Tensor对象。由于我们知道这些图片是真实的,所以我们要做对应数量的真实数据的标签real_label。

在这里插入图片描述
用判别器对这些真实图片进行预测,得到计算结果。然后用这个计算结果与真实数据的标签real_label进行计算损失值。

在这里插入图片描述

接下去就是计算假图片的损失值了。为了得到假的数据,我们就先生成一批随机噪音数据z。然后生成器用噪音数据z生成假图片。这个生成假图片的过程是为训练判别器而提供数据,不是为了训练生成器,所以这里在反向求导时要停止计算梯度。得到假数据之后,与前面真实数据一样处理:依次计算判别器的输出结果,计算判别器对假图片的损失值。

在这里插入图片描述

判别器的总损失值就是两个损失值之和。得到了判别器的损失值之后就可以进行反向传播,参数优化了。

在这里插入图片描述

训练了判别器之后,就可以训练生成器了。前面已经有生成器对象G了,这里就设定它的优化器:

在这里插入图片描述

训练生成器的过程:先获得噪音z,然后用G生成假图片。这里要训练G,所以这里不能用detach()处理。接下去是求判别器的计算结果,然后是计算生成器的损失值。由于生成器希望生成的图片可以以假乱真,所以要跟真实图片的label一起求损失值。在得到损失值之后,就可以进行反向传播和参数优化了。至此,终于完成生成器和判别器的训练了。

在这里插入图片描述

总结:
本节从VAE与GAN的对比开始引入GAN,并提出GAN系统存在的一些问题 (不完全),这些问题在后续改进版本的GAN中已经部分或全部解决了。我们这里只是实现最原始的GAN。我们先分别建立了简单的生成网络和判别网络,然后讨论了应该先训练判别网络的原因。接下去,我们依次编写了训练判别网络和生成网络的代码。总的代码请参考相关算法源文件。
下面,我想请读者玩一个GAN游戏:两个人分别代表G和D。其中D心中想着一个物体 (例如一辆小汽车),然后G开始用一团面粉开始捏。G每改变一点,D都对这个变化进行评判,它是不是更加像D心中想的物体的形状。经过足够多次改进后,G是不是可以把面粉捏的跟D心中的物体很相似。玩过这个游戏之后,相信读者一定会对GAN有更深的体会。
下一节我们要介绍WGAN,它给出了GAN存在的难以训练、训练不稳定等问题的一个通用解决方案。

from torch import nn
from torch.nn import functional as F
from book import GAN_Resource

class GAN_Generator(nn.Module):
    def __init__(self):
        super(GAN_Generator, self).__init__
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值