本文主要讲解有关生成对抗网络(GAN)的相关知识。
一、判别模型和生成模型
机器学习中的模型一般有两种:1. 决策函数 Y=f(X);2. 条件概率分布 P(Y|X)
根据通过学习数据来获取这两种模型的方法,可以分为判别方法和生成方法。判别方法是由数据直接学习决策函数或条件概率分布作为预测模型,即判别模型;而生成模型是由数据学习联合概率分布 P(X,Y),然后由 P(Y|X)=p(X,Y)/P(X) 求出概率分布 P(Y|X) 作为预测模型,即生成模型。
二、生成对抗网络(GAN)
生成对抗网络(GAN)启发自博弈论中的零和博弈(即两人的利益之和为零,一方的所得正式另一方的所失),GAN 模型由生成模型(generative model)和对抗模型(discriminative model)组成。生成模型 G 捕捉样本数据的分布,用服从某一分布的噪声 z 来生成一个类似真是训练数据的样本,追求效果越像真实样本越好;判别模型 D 是一个二分类器,估计一个样本来自训练数据(而非生成数据)的概率。生成器就类似于造假币的人,而判别器就类似于验钞机,生成器的目的就是其造的假币要骗过验钞机。
GAN 的目标函数:
min
G
max
D
V
(
D
,
G
)
=
E
x
−
p
d
a
t
a
(
x
)
[
log
D
(
x
)
]
+
E
z
−
p
z
(
x
)
[
log
(
1
−
D
(
G
(
z
)
)
)
]
\min_G\max_DV(D,G)=E_{x-p_{data}(x)}[\log D(x)]+E_{z-p_z(x)}[\log(1-D(G(z)))]
GminDmaxV(D,G)=Ex−pdata(x)[logD(x)]+Ez−pz(x)[log(1−D(G(z)))]
其中 D(x) 表示真实数据通过判别器 D 的输出结果,而 D(G(z)) 是噪声 z 通过生成器 G 生成的假数据通过判别器 D 的输出结果。所以对于判别器 D 来说,需要最大化目标函数,即通过判别器让真币为真的概率越大越好,而假币为真的概率越小越好。对于生成器 G 来说,需要最小化目标函数,即让假币越真越好。
令 C ( G ) = max D V ( G , D ) C(G)=\max_DV(G,D) C(G)=maxDV(G,D),则 C ( G ) = − log ( 4 ) + 2 ⋅ J S D ( p d a t a ∣ ∣ p g ) C(G)=-\log(4)+2\cdot JSD(p_{data}||p_g) C(G)=−log(4)+2⋅JSD(pdata∣∣pg),其中 JSD 表示 Jensen-Shannon divergence,即 JS 散度。 p d a t a p_{data} pdata 是真实数据的分布,而 p g p_g pg 是生成的假数据的分布。
GAN 存在训练过程不稳定的问题,这一方面是因为 GAN 自身的缺陷,另一方面是因为生成器和判别器的能力不匹配;此外生成器只会生成一两种类别的样本。
GAN 的一个改进是 WGAN。当真实数据的分布和假数据的分布互不重叠时,JS 散度值会趋近于一个常数,其导数接近于0,这就导致了梯度消失。所以重新定义了一种 Wasserstein-1 距离来代替原来的 JS 散度:
W
(
P
r
,
P
g
)
=
inf
γ
−
Π
(
P
r
,
P
g
)
E
(
x
,
y
)
[
∣
∣
x
−
y
∣
∣
]
W(P_r,P_g)=\inf_{\gamma-\Pi(P_r,P_g)}E_{(x,y)}[||x-y||]
W(Pr,Pg)=γ−Π(Pr,Pg)infE(x,y)[∣∣x−y∣∣]
即使
P
r
P_r
Pr 和
P
g
P_g
Pg 互不重叠,wasserstein 距离依旧可以清楚的反应两个分布的距离。
目标函数也变为了:
max
f
w
E
x
−
P
r
[
f
w
(
x
)
]
−
E
z
−
P
z
[
f
w
(
G
(
z
)
)
]
\max_{f_w}E_{x-P_r}[f_w(x)]-E_{z-P_z}[f_w(G(z))]
fwmaxEx−Pr[fw(x)]−Ez−Pz[fw(G(z))]
min G − E z − P z [ f w ( G ( z ) ) ] \min_G-E_{z-P_z}[f_w(G(z))] Gmin−Ez−Pz[fw(G(z))]
WGAN 很好的解决了训练不稳定和模式崩溃的问题。
GAN 只能随机产生一个类别,CGAN 可以指定类别来生成。