这篇文章是李宏毅老师Generative Adversarial Network(GAN)的一篇笔记,对生成对抗网络感兴趣的同学,非常推荐去看李宏毅老师的专题 https://www.youtube.com/watch?v=DQNNMiAP5lw&list=PLJV_el3uVTsMq6JEFPW35BCiOQTsoqwNw
另外一直有在持续更新GAN的相关内容,大家有兴趣的可以继续阅读~
王希玺:生成对抗网络GAN(二):Conditional Generation 条件生成网络的理解与算法流程
Unsupervised Conditional Generation无监督条件生成网络
这篇文章主要是介绍GAN的基本理念和算法流程,可以看做是GAN的入门。只介绍了算法的直观印象,有关GAN的数学原理,可以参考这篇 生成对抗网络(GAN) 背后的数学理论
GAN的基本概念
GAN有两个部分,生成器和判别器:
生成器
生成器的基本概念其实很简单,输入一个向量,通过一个NN,输出一个高维向量(可以是图片,文字...)通常Input向量的每一个维度都代表着一些特征。
判别器
同时呢GAN还有一个部分,叫做“discriminator”(判别器),它的输入是你想产生的东西(其实就是生成器产生的output),比如一张图片,或者一段语音...它的输出是一个标量,这个标量代表的是这个Input的质量如何,这个数字越大,表示这个输入越真实。
生成器和判别器的关系
其实就是生成器生成一个东西,输入到判别器中,然后由判别器来判断这个输入是真实的数据还是机器生成的,如果没有骗过判别器,那么生成器继续进化,输出第二代Output,再输入判别器,判别器同时也在进化,对生成器的output有了更严格的要求。这样生成器和判别器不断进化,他们的关系有点像一个竞争的关系,所以有了“生成对抗网络(adversarial)”的名字的由来。
下面讲一下GAN的操作:
GAN算法流程简述
- 初始化generator和discriminator
- 每一次迭代过程中:
- 固定generator, 只更新discriminator的参数。从你准备的数据集中随机选择一些,再从generator的output中选择一些,现在等于discriminator有两种input。接下来, discriminator的学习目标是, 如果输入是来自于真实数据集,则给高分;如果是generator产生的数据,则给低分,可以把它当做一个回归问题。
- 接下来,固定住discriminator的参数, 更新generator。将一个向量输入generator, 得到一个output, 将output扔进discriminator, 然后会得到一个分数,这一阶段discriminator的参数已经固定住了,generator需要调整自己的参数使得这个output的分数越大越好。
按这个过程听起来好像有两个网络,而实际过程中,generator和discriminator是同一个网络,只不过网络中间的某一层hidden-layer的输出是一个图片(或者语音,取决于你的数据集)。在训练的时候也是固定一部分hidden-layer,调其余的hidden-layer。当然这里的目标是让output越大越好,所以做的不是常规的梯度下降,而是gradient ascent, 当然其实是类似的。
GAN算法-具体操作
上面用通俗的语言解释了GAN的算法流程,现在将其正式化:
初始化
在每次迭代中:
- 从数据集
中sample出m个样本点,这个m也是一个超参数,需要自己去调
- 从一个分布(可以是高斯,正态..., 这个不重要)中sample出m个向量
- 将第2步中的z作为输入,获得m个生成的数据
- 更新discriminator的参数
来最大化, 我们要使得越大越好,那么下式中就要使得越小越好,也就是去压低generator的分数,会发现discriminator其实就是一个二元分类器:
-
-
(也是超参数,需要自己调)
1~4步是在训练discriminator, 通常discriminator的参数可以多更新几次
5. 从一个分布中sample出m个向量
6. 更新generator的参数
5~6步是在训练generator,通常在训练generator的过程中,generator的参数最好不要变化得太大,可以少update几次