一、简介
最近在研究深度学习相关的知识,看了CNN、RNN、DNN等经典的神经网络,然后研究了一下生成模型,也就是今天要讲的生成对抗网络(GAN),打算出一个系列,毕竟关于生成对抗网络的论文太多了,github上有整理,有兴趣的小伙伴可以自己看看原论文顺便跑一下代码,真的很有意思。
GAN自诞生起一直颇受赞誉,后期也衍生了很多变种,大多数只是在损失函数这一块做文章,不过也有颇多成效,GAN也有它自身的缺点,不过在WGAN以后得到了很大的改善。GAN的提出者是Goodfellow, 据说他是在喝醉了以后偶然想出来的,唉,这大概也是人与人之间的差距吧。
GAN本质上是一个minimax游戏,它会同时训练两个神经网络G和D。G负责生成样本;D负责判断这个样本是真实的样本还是G生成的样本。G和D是分开训练的,损失函数如下:
(不是很会使用CSDN的公式编辑器,所以暂时截图代替)
D(x)表示x来自于真实数据的概率。z为随机噪音,为生成器提供泛化能力。
从判别器D的角度看,它要将生成的图片判别为fake(D(x)=0), 真实的训练图片判断为real (D(x)=1);从生成器的角度看,它要让生成的图片尽可能真实,对应于上式右边第二项。
二、原理
GAN的网络结构图如下
(这是到网上找到的图片,如果有侵权请告知我)
生成器和判别器是分开训练的,训练生成器时,判别器参数固定,整个网络只训练生成器;训练判别器时是一样的。整个网络存在最优解,此时,D(x=fake)=1/2,说明此时判别器无法区分出生成图片和训练图片,生成器的目的已经达到。
以上便是GAN的训练算法,判别器训练k轮,生成器训练1轮,交替进行训练。
三、效果
我利用了github上的代码,代码的结构与上述说明一致。总共训练了30000轮。我们采用了mnist数据集。
初始生成图片如下:
就是一团随机噪音而已。
训练1000轮后:
开始有了数字的模样。
训练5000轮以后:
已经很不错了。训练30000次的效果和这差不多。
原始GAN还有如下缺点:
1.难以训练到收敛。这个问题从训练过程中可以看出来,最后的生成器误差大概在0.6左右,之后会在此上下波动,难以达到最佳的0.5。这个和生成器的损失函数有一定关系,后续的WGAN专门针对这个问题进行了改进。
2.图片类别无法控制。比如对于mnist数据集,能够生成0-9的数字图片,但是我如果只想生成数字0的图片呢?原始GAN是无法做到的,这个在后面的GAN版本里有改进,措施是在训练的同时加入类别标签。
3.生成器容易崩溃。这个问题的原因是生成器可能找到某种trick,用其来欺骗判别器。此时生成的图片不再具有参考意义。
四、总结
以上就是GAN的基本内容了。虽然GAN还有不少问题,但还是具有重大的指导意义,后来者也提出了一系列的改进方案,由此产生了一系列的衍生版本。
如果本文有什么问题,可以与我探讨。