本文仍然基于李宏毅老师的精彩讲授随笔记录
对于Generation问题,它要做的事情就是生成的样本的分布尽可能逼近真实样本分布,即如下图:
因此,传统的方法就是从真实样本中取样,然后使用MLE让生成样本的分布
去尽可能逼近它,我们使用MLE去推算下看看出现什么。
其中约等号成立是因为样本是从中取样的,因此可以看做是对
的Expectation。然后因为目标是求
,因此后面减去一项与
无关的东西不影响结果,最后我们得出这样的结论:
极大化MLE做的事情就是在极小化与
的KL Divergence,但由于传统的generation做法都是使用像高斯混合模型来作为
,导致模型表达能力有限,且参数难以确定,因而效果不是很好。而GAN则是使用一个Generator来替代了这个
的效果。
我们可以将Generator看成是一个概率分布的生成器,他是一个network,因此具有很强的表达能力,也就是它将我们随机输入的噪声分布转化成,来达到生成样本的效果。而Discriminator的效果则是在计算二者的Divergence(这里其实是JS-Divergence,后面会推理出来),因此二者相互作用,使得generator生成的概率分布越来越逼近真实分布。
那么为什么Discriminator可以实现衡量两个分布的Divergence的效果呢?
我们可以看到,在最早的GAN中,Discriminator的Objective Function是:
而
最大化V其实就是说我们希望当X是真实样本时,我们尽量给他1,当X时generator生成的样本时,我们尽量给他0这样。
而为什么就等价于在衡量二者的Divergence呢?下面我们来推一下。
根据上面的推导,我们可以看到,当我们训练Discriminator直到收敛时,获得D*,这时候我们使用和
通过D*计算出来的objective function其实就是在衡量二者的JS-Divergence。因此每次我们训练完D*之后,我们就可以frozen住D*,然后训练Generator,希望Generator生成的
与真实样本
通过D*后的objective function尽量小,也就是尽量减少二者的JS-Divergence。
总结一下:训练Discriminative的目的是因为收敛获得的D*实际上就是在衡量JS-Divergence,而固定Discriminative训练Generator的目的就是为了让Generator产生的能尽可能地与
的JS-Divergence最小。这就是原始GAN的思想。
因此最后Generator的目标即写为:
比如上面的三幅图,那么我们最终选择的G*就应该是G3.
不过这里是有一个小问题的,那就是:
当我们获得,然后去更新G0得到G1,这个时候你的V(G1,
)可能就不是在衡量JS-Divergence了,最坏的情况甚至有可能你更新完之后出现右图那个情况,就是更新之后新的G1与Pdata的JS-Divergence反而变大了,因为更新到G1后,这个时候
并不是maxV的解了,也就不是在衡量JS-D了,所以有可能会这样。解决的方法就是,我们每次更新G的时候,不要更新太多,让
,也就是右图两个曲线是非常接近的,这样我们才能认为说,你在更G,确实是在减少二者的JS-Divergence。所以GAN的算法中,对Discriminator的更新是多次的,为了尽可能找到max的D*,而更新Generator的时候,我们只更新一次,就是因为这个假设。
下面我们看看实做上,是如何train Discriminator的:
我们理论上是要maximize V,但是由于我们没办法真的获得分布的Expectation,因此遵循大数定律,我们从与
中sample m笔data出来计算均值,但是其实,计算均值这个事情是和我们train一个binary classifier使用logistic loss function的效果是等价的,所以我们只需要将Discriminator设置为一个普通的二分类classifier,使用cross-entropy loss进行训练就可以了。
最后给出算法流程
简单解释下,在Learning D的时候,前面说过Update D的公式其实就是在train一个binary classifier,因此实做起来很容易,然后为了尽可能让D接近D*,我们每次迭代训练D k times。然后Learning G的时候,我们同样只需要让生成的样本通过D尽可能被判别为1,也就是减少JS-Dviergence就可以了,这里之所以把前面一般划掉,是因为前面一半只和D有关,在训练G的时候我们是固定D的,因此只需要训练后面一半式子就够了,也就是只需要让G生成的样本尽可能被误判就行。之前说过,为了让,因此这里我们每次迭代更新G,都只更新一次。
先到这里,李宏毅老师的视频价值真的很大,如果直接看paper,有很多地方会看的不太懂。看了视频再看paper,加深理解,感谢李宏毅老师!