生成对抗网络(GAN)的原理与公式推导

论文地址

https://arxiv.org/abs/1406.2661

希望路过的学霸们可以看看最后一句话

一、基本概念

生成对抗网络(Generative Adversarial Networks)是一种无监督深度学习模型,用来通过计算机生成数据,由Ian J. Goodfellow等人于2014年提出。模型通过框架中两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。生成对抗网络被认为是当前最具前景、最具活跃度的模型之一,目前主要应用于样本数据生成、图像生成、图像修复、图像转换、文本生成等方向。

二、GAN的原理

1、Generator(生成器)

随机输入高斯噪声(注意这个噪声一定要是来自简单的分布,容易采样的,例如normal distribution),生成图片(概率分布 distribution)。
我们将生成的图片的数据分布叫做Pg真实图片的数据分布叫做Pdata
在这里插入图片描述
我们的目的:通过Generator生成高质量的图片。也就是说,通过Generator生成的数据分布,与真实图片的数据分布,越接近越好
在这里插入图片描述

2、Discriminator(判别器)

输入图片,给图片打分,输出为一个Scalar(数字)。
如果图片来自真实图片,给高分;来自生成的图片,给低分
在这里插入图片描述

3、强强联合

在GAN中,我们就是在不断的训练G和D,让G生成的图片的分布尽可能逼近真实图片的分布。
在训练过程中:
D学着去分辨真实的图片和G生成的图片,企图给真实的图片打高分,给G生成的图片打低分。
G学着去骗过D,企图让D分辨不出来图片是来自真实的数据集还是G生成的。达到“以假乱真”的效果。

总结:G其实就是在学会做高仿品。D也在不断变强,增强自己辨真假的能力。这就是“对抗”,毕竟,只有足够强大的对手才能逼出更强大的自己。

小例子:
这个例子来自李宏毅老师的机器学习课程,链接在这(b站):
李宏毅2021/2022春机器学习课程—生成式对抗网络(GAN)
20分20秒处
在这里插入图片描述

4、算法

在每轮的迭代过程中:
Step1:固定Generator,更新Discriminator;

在这里插入图片描述
Step2:固定Discriminator,更新Generator;

在这里插入图片描述

整个迭代步骤:
在这里插入图片描述

终极目标:通过G生成D难以分辨的图片,D打分为0.5,因为它不知道是真是假。

三、公式推导

刚才说到,我们希望通过G生成的分布与真实图片的分布,越接近越好。
那么,如何衡量两个分布之间的相似程度(或者说差异性)呢
巧了,在数学中有一些东西,可以做到这件事情,它叫做KL散度(JS散度等等也可以,不止这一种)。

1、KL散度、JS散度

KL散度(KL Divergence):也称相对熵、KL距离。对于两个概率分布P和Q之间的差异性(也可以简单理解成相似性),二者越相似,KL散度越小

关于KL散度具体的内容,推导,可以看看b站王木头讲的,个人觉得讲的很清楚了,链接如下(其实我觉得弄懂了好像也没啥用,可能是我太菜了):
“交叉熵”如何做损失函数?打包理解“信息量”、“比特”、“熵”、“KL散度”、“交叉熵”

这里,只列出KL散度的计算公式(离散随机变量):
在这里插入图片描述

JS散度(JS Divergence):JS散度是KL散度的一种变体,与KL散度相似,P和Q越相似,JS散度越小。
在这里插入图片描述
与KL散度相比,JS散度具有对称性,也就是说,P和Q可以交换位置,而KL散度不具有这个特点。

2、求解Generator

其实现在,想必大家也猜到了G的优化公式。那就是,最小化Pg和Pdata的KL散度
在这里插入图片描述
其实这个divergence就是Generator中的loss function。
Generator也是一个神经网络,我们就是在找出一组权值(weight)和阈值(bias),使得Pg和Pdata之间的KL散度,越小越好。
但是,通过KL散度的公式,我们会发现,要求解KL散度,必须要知道Pg和Pdata的分布,你要事先知道他们俩长什么样,所以,我们很难直接算出Pg和Pdata之间的KL散度。那么,该如何得到KL散度呢?
这就是GAN神奇的地方,它告诉你说,你不需要知道Pg和Pdata的公式长什么样子,就可以计算KL散度,这就要依靠Discriminator的力量

3、求解Discriminator

GAN提出的对于D的优化是这样的:
在这里插入图片描述

我们希望给来自Pdata的样本高分,给来自Pg的样本低分,所以就是要最大化这个目标方程。
GAN神奇的对方就在于,这个最大化的目标方程,是和JS散度(KL散度也可以)有关的。

接下来就来求解一下最优的D,并求解这个V(G,D)

在这里插入图片描述

注意:上文说过,这里是固定G,求解D

在这里插入图片描述

D* 就是使得目标方程 V(G,D) 最大的D
我们把 D* 带入到 V(G,D)

在这里插入图片描述
在这里插入图片描述

可以看到,求解得到的 max V(G,D) 与JS散度有关,那么,求解 max V(G,D) ,其实就是在求Pg和Pdata的JS散度,G的优化函数就可以写成如下形式:

在这里插入图片描述
也就得到了,论文中提到的那个方程:
在这里插入图片描述

4、整个训练的具体步骤

在这里插入图片描述
注:
(1)这里用采样后求均值的方法代替分布的期望
(2)更新G的时候,不能更新太多。我们更新G的时候,是固定了D,那么如果G更新过多的话, V(G,D) 可能会发生很大变化,以至于当前的D可能已经不是使 V(G,D) 达到最大值的D,那么这时候我们通过梯度下降减小的也就不是JS散度了。所以,这里实际上是假设了,梯度下降的每一步更新后,使得 V(G,D) 达到最大的D是基本没变的。如下图所示:
在这里插入图片描述

四、存在的缺点与问题

(1)没有显示地表达Pg
(2)D必须与G同步训练,且G不能更新太多。最终需要达到纳什平衡,但是有时候是做不到的,所以训练起来有时候是不稳定的,且生成器和图像质量之间缺乏相关性。
(3)模式崩溃问题。就是说G会偷懒,它发现有一种方法可以无限次骗过D,那它就会得寸进尺,每次都用这种方法来骗D,换到我们这个场景里面来说就是,生成的图片比较单一,缺乏多样性。就像下图所示:
在这里插入图片描述
生成的样本全部聚集在左边的峰下,这时虽然生成样本的质量比较高,但是生成器完全没有捕捉到右边的峰的模式。(如果使用多种猫的图像训练GAN,最终GAN只能产生逼真的英短,而无法产生其他品种)。

五、GAN的改进

1、Conditional GAN

论文地址:https://arxiv.org/abs/1411.1784
在原始的GAN基础上加了条件y,分别加到G和D中。
在这里插入图片描述

2、Deep Convolutional GAN

论文地址:https://arxiv.org/abs/1511.06434
将CNN加入到GAN中。

3、DRAGAN(On Convergence and Stability of GANs)

论文地址:https://arxiv.org/abs/1705.07215
用梯度惩罚方案解决了模式崩溃问题。

4、Cycle Gan

论文地址:https://arxiv.org/abs/1703.10593
图像转换。
在这里插入图片描述
还有很多GAN,我也没看几个,就把最原始的GAN仔细的从头到尾看了下,希望这篇可以对大家有所帮助。

接下来准备看Diffusion Model,希望同学们可以多多帮助,分享下学习经验,paper,视频等等,谢谢了!研一菜鸟一枚。

  • 9
    点赞
  • 52
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值