GAN属于生成模型,使用生成数据分布 PG 去无限逼近数据的真实分布 Pdata 。衡量两个数据分布的差异有多种度量,例如KL散度等,但是前提是得知道 PG 。GAN利用discriminator巧妙地衡量了 PG,Pdata 的差异性,利用discriminator和generator的不断竞争(minmax)得到了好的generator去生成数据分布 PG 。
背景
很多时候,我们想输入一类数据,然后让机器学习这一类数据的模式,进而产生该类型新的数据。例如:
- 输入唐诗三百首,输出机器写的唐诗
- 输入一堆动漫人物的照片,输出机器生成的动漫人物照片
该问题的核心是原数据有其分布 Pdata ,机器想要学习新的分布 PG 去无限逼近 Pdata 。
一个简单的解决办法是采用异常检测的模型,通过输入大量的正常数据,让机器学习正常数据的内在规律。例如:自编码器模型如下。通过训练数据学习到数据的内在模式code。学习到code后,随机输入新的code便可以产生数据。
对于mnist数据,设code为2维,训练之后输入code得到的图片如下:
但是这种情况下,机器学习到的只是这个数据大概长什么样,而不是数据的真实分布。例如下图的两个7,在人看来都是真的图片7,但是机器却不这么认为。
结构
GAN由generator和discriminator两部分组成:
z -> G -> x' -> D -> 01
x ->
- generator:输入随机的
z
,输出生成的
x′ - discriminator:二分类器,输入生成的 x′ 和真实的 x ,输出01(是否是真的数据)
GAN的训练,也包括generator和discriminator两部分:
discriminator的训练,设generator不变,通过调整discriminator的参数让discriminator尽可能区分开
x,x′ 。
generator的训练,设discriminator不变,通过调整generator的参数让discriminator尽可能区分不开 x,x′ 。
整体来看,generator和discriminator构成了一个网络结构,通过设置loss,保持某一个generator和discriminator参数不变,通过梯度下降更新另外一个的参数即可。
训练
最大似然估计
已知两个分布
Pdata(x)
和
PG(x;θ)
,目标是找到
G
的
采用最大似然估计,有:
也就是说,最大似然 PG(x;θ) 的概率等价于:最小化基于 PG(x;θ) 的编码来编码 Pdata(x) 所需的额外位元数。也就是最小化KL散度。
下面只需要计算出
PG(x;θ)
,一切问题似乎都解决了。事实确实这样,不过对于不同的
G
,
这样来看,传统的最大似然是走不通呢,有没有别的出路呢?
考虑最大似然法真正解决的问题。最大似然就是提供了某种手段,去衡量两个分布 Pdata(x) 和 PG(x;θ) 的相近程度。此路不通另寻他路即可。因此便引出了下文的 V(G,D) 。
V(G,D)取代最大似然估计
V(G,D) 是衡量两个分布 Pdata(x) 和 PG(x;θ) 相近程度的一种手段,其不同于最大似然,是通过一个额外的discriminator识别的好坏做评估的。其核心是:discriminator判别数据是真的数据(1)还是采样的数据(0)。如果两个分布很接近,那么discriminator分辨不清,效果比较差;如果两个分布很远,那么discriminator分辨清,效果比较好。
整个训练策略,是先固定
G
选择
D的训练
这部分解决的是:对于特定的G,如何训练得到更好的D。
首先,对 V(G,D) 做进一步分解:
所以有:
对上述式子求导得到:
每个
D∗(x)
对应的
V(G,D∗)
实际上衡量了特定
G
下面两个分布
将
D∗(x)
代入
V(G,D∗)
,有:
所以:固定G优化D的过程,相当于计算两个分布的距离:
得到两个分布的距离之后,便转化成最小化两个分布的距离的问题,
也就是:
G的训练
固定G优化D得到 D∗ 便得到了两个分布的距离 V(G,D∗) ,固定 D∗ 优化G,采用梯度下降即可。
算法
问题
G的更新优化不一定朝着最小的方向
优化G之后,原来的D对应的就不一定是
maxV(G,D)
最大的
G
了,这样与我们的假设不同。
解决办法是:就像梯度更新的时候迈的步子不能太大;更新G的时候迈的步子也不要太大。
通过抽样估计分布
G中的目标函数
实际训练G的时候,目标函数需要做一些修改,修改的原因是:在刚开始训练的时候,
利用D去评估分布差异
理论上,可以用D去评估分布差异(G的好坏),D越好表明G越差,D越差表明G越好。但是实际中,这样评价的效果不好,不论G的好坏,D都比较好。
可能的原因之一:D太强大了。直观的解决办法是让D变弱一些,但是这样得到的D可否真正的计算JS divergence是个问题。
可能的原因之二:数据的本质是高维空间的manifold,很少有重叠。没有重叠的话,js距离永远都是最大值,不容易学习更新。
通常的解决办法是:给D的输入增加人为的噪声,这样真实数据与人造数据就会有重叠,D不能很好地区分真实数据与人造数据。同时要注意噪声要随时间减少。
mode collapse
mode collapse值的是GAN只学到了数据多个形态中的某一种。
可能的原因是优化式使GAN趋向如此: