对抗网络 GAN背后的理论
对抗网络GAN是由生成器Generator最终生成图片、文本等结构化数据。
生成器能生成结构化数据的原理是什么呢?
简而言之:就是让生成器Generator找到目标图片、文本的信息的概率密度函数。通过概率密度函数 P d a t a ( x ) P_{data}(x) Pdata(x),生成数据。
1. 最大似然估计近似
任何复杂的问题都可以拆解为简单的问题。 在机器学习中最大似然估计就是基本问题。
我们再对抗网络中使用最大似然估计:
- 我们首先获得目标数据的概率密度函数 P d a t a ( x ) P_{data}(x) Pdata(x)
- 我们设定Generator的概率密度函数为 P G ( x ; θ ) P_{G}(x;\theta) PG(x;θ)
- 找到 θ \theta θ 能够让 P G ( x ; θ ) P_{G}(x;\theta) PG(x;θ) 越来越接近P_{data}(x)$
- 举例: 假设 P G ( x ; θ ) P_{G}(x;\theta) PG(x;θ) 属于高斯分布, θ \theta θ 就代表高斯分布的参数均值mean和方差Variance。
具体的做法是:
- 抽取sample x 1 , x 2 . . . x m {x_1,x_2...x_m} x1,x2...xm from P d a t a ( x ) P_{data}(x) Pdata(x)
- 计算最大似然函数:
目的是让似然函数的结果最大,我们就找到 θ \theta θ
计算最大似然值可以推导如下:
注:上式中,有半部分是一个固定的值,求最小值的时候,减去固定的值,对参数
θ
\theta
θ 的结果不影响。(目的是凑出KL距离公式)
实际上划归为计算 P G ( x ; θ ) P_{G}(x;\theta) PG(x;θ) 和 P d a t a ( x ) P_{data}(x) Pdata(x)KL距离的最小值得问题:
2. 生成器Generator
Generator G就是一个神经网络,它定义了生成器的 P G ( x ; θ ) P_{G}(x;\theta) PG(x;θ)
G的目标是:找到
P
G
(
x
;
θ
)
P_{G}(x;\theta)
PG(x;θ) 和
P
d
a
t
a
(
x
)
P_{data}(x)
Pdata(x)之间的差距最小
2. 鉴别器 Discriminator
鉴别器Discriminator D 就是需要鉴别那些数据是来自Generator G的 P G ( x ; θ ) P_{G}(x;\theta) PG(x;θ) ,那些数据是来自真实数据 P d a t a ( x ) P_{data}(x) Pdata(x)。
D的目标是:更可能的能区分真实数据和生成数据,做好一个质量检查员,而且还需要在工作中不断学习。
举例:
我们有数据
鉴别器对于数据的鉴别难度,取决于数据的概率分布的差距:
用公式表示鉴别器的目标(G是固定的):
其中V的值是(G是固定的):
推导如下:
我们需要找到一个最好的鉴别器D
就是最大化:
3. 算法的详细过程
3.1 数学推导
算法的核心是:
这个公式看上去一头雾水,我们慢慢拆解它。
首先,我们去看:
我们需要挑出最好的鉴别器:让V最大。
那么V 我们知道:
我们把它转换为求最大值的普通数学问题(大一或者高三知识就可以求解)
其中a,b 都是固定值,求最大值D,我们推导一下:
求出最优的 D ∗ D^{*} D∗, 我们把它代回得到:
注明:其中最后的结果中有JS 距离。和KL一样,JS距离也是衡量两种概率分布的工具。
求解完D后我们再看下,最小化
m
a
x
V
(
G
,
D
)
maxV(G,D)
maxV(G,D),是什么意思。我们假设存在三个G,G1,G2,G3, 每一个G都有一个
m
a
x
V
(
G
,
D
)
maxV(G,D)
maxV(G,D)。
很显然,算法最终的结果是选择G3。
3.2 算法过程
算法过程看起来比较简单,但是实际操作中遇到很多很问题。GAN是比较难以”驯服“的。
实际操作:
-
给定G,计算 m a x V ( G , D ) maxV(G,D) maxV(G,D)
抽取sample x 1 , x 2 . . . x m {x_1,x_2...x_m} x1,x2...xm from P d a t a ( x ) P_{data}(x) Pdata(x);抽取sample x ‘ 1 , x ’ 2 . . . x ‘ m {x‘_1,x’_2...x‘_m} x‘1,x’2...x‘m from P G ( x ) P_{G}(x) PG(x),计算最大值。
D实际上是我们学过的最简单的二元分类器。
我们需要找到一个最好的D。
- 给定D,找到能让 P d a t a ( x ) P_{data}(x) Pdata(x)和 P G ( x ) P_{G}(x) PG(x)分布距离最小的G。
整体算法过程:
注明:GAN的object 函数很难训练,刚开始的变化比较小。
其中给定D的情况下,V的左半部分是固定值,我们可以不用考虑。
实操中:V可以写作
这样,函数图像变为:
这样的函数,就相对好train许多。
本专栏图片、公式很多来自台湾大学李宏毅老师的深度学习课程,在这里,感谢这些经典课程,向李老师致敬!