We now analyze, at a theoretical level, what equilibrium states the generator G and the discriminator D each reach under this game-theoretic training scheme. Specifically, we will explore the following two questions:
- With G fixed, what optimal state $D^*$ does D converge to?
- After D reaches its optimal state $D^*$, what state does G converge to?
We first build intuition with a one-dimensional example in which the real data follow a normal distribution, $\boldsymbol x_r\sim p_r(\cdot)$. In the figure below, the black dashed curve is the real data distribution $p_r(\cdot)$, a normal distribution $\mathcal N(\mu,\sigma^2)$; the green solid curve is the distribution learned by the generator network, $\boldsymbol x_f\sim p_g(\cdot)$; and the blue dashed curve is the discriminator's decision boundary. Panels (a), (b), (c), and (d) show successive stages of the generator's learning trajectory. In the initial state, panel (a), $p_g(\cdot)$ differs substantially from $p_r(\cdot)$, so the discriminator easily learns a clear decision boundary (the blue dashed line in (a)) that scores samples from $p_g(\cdot)$ as 0 and samples from $p_r(\cdot)$ as 1. As the generator's distribution $p_g(\cdot)$ moves closer and closer to the real distribution $p_r(\cdot)$, it becomes increasingly difficult for the discriminator to separate real samples from fake ones, as shown in panels (b) and (c). Finally, when the learned distribution satisfies $\boldsymbol{p_g(\cdot)=p_r(\cdot)}$, samples drawn from the generator are realistic enough that the discriminator can no longer tell them apart: it assigns equal probability to a sample being real or fake, as shown in panel (d).
This example gives an intuitive picture of the GAN training process.
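To make the figure concrete, here is a minimal sketch (illustrative only; the means, variance, and stage values are assumptions, not taken from the original figure) that evaluates the theoretically optimal discriminator $D^*(\boldsymbol x)=p_r(\boldsymbol x)/(p_r(\boldsymbol x)+p_g(\boldsymbol x))$, derived below, as the generator mean moves toward the real mean. Its output shows the decision curve flattening toward 0.5, mirroring panels (a) through (d):

```python
import numpy as np

def gaussian_pdf(x, mu, sigma):
    """Density of N(mu, sigma^2), evaluated elementwise."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

x = np.linspace(-5.0, 5.0, 1001)
p_r = gaussian_pdf(x, mu=0.0, sigma=1.0)       # real distribution (black dashed curve)

for mu_g in [3.0, 1.5, 0.5, 0.0]:              # rough stand-ins for stages (a)-(d)
    p_g = gaussian_pdf(x, mu=mu_g, sigma=1.0)  # generator distribution (green curve)
    d_star = p_r / (p_r + p_g)                 # optimal discriminator (blue dashed curve)
    print(f"mu_g = {mu_g:4.1f}  ->  D*(0) = {d_star[500]:.3f}")
# As mu_g -> 0, p_g -> p_r and D* collapses to 0.5 everywhere.
```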
1. Discriminator State
We now work through the first question. Recall the GAN loss function:

$$\begin{aligned}\mathcal L(G,D)&=\int_x p_r(\boldsymbol x)\log\big(D(\boldsymbol x)\big)\text{d}\boldsymbol x+\int_z p_z(\boldsymbol z)\log\big(1-D(g(\boldsymbol z))\big)\text{d}\boldsymbol z\\ &=\int_x p_r(\boldsymbol x)\log\big(D(\boldsymbol x)\big)+p_g(\boldsymbol x)\log\big(1-D(\boldsymbol x)\big)\text{d}\boldsymbol x\end{aligned}$$
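As a sanity check on this expectation form, the following sketch (an assumed toy setup: the generator, discriminator, and noise distribution are all hand-picked for illustration) estimates $\mathcal L(G,D)$ by Monte Carlo sampling of the two expectations:

```python
import numpy as np

rng = np.random.default_rng(0)

def g(z):
    """Toy 'generator': shift the noise, so p_g = N(1.5, 1)."""
    return z + 1.5

def D(x):
    """Toy 'discriminator': a fixed logistic score in (0, 1)."""
    return 1.0 / (1.0 + np.exp(x - 0.75))

x_r = rng.normal(0.0, 1.0, 100_000)  # samples from p_r = N(0, 1)
z = rng.normal(0.0, 1.0, 100_000)    # samples from p_z = N(0, 1)

# L(G, D) = E_{x~p_r}[log D(x)] + E_{z~p_z}[log(1 - D(g(z)))]
loss = np.mean(np.log(D(x_r))) + np.mean(np.log(1.0 - D(g(z))))
print(f"Monte Carlo estimate of L(G, D): {loss:.4f}")
```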
For the discriminator D, the optimization objective is to maximize $\mathcal L(G,D)$, which requires finding the maximum of the function

$$f_\theta=p_r(\boldsymbol x)\log\big(D(\boldsymbol x)\big)+p_g(\boldsymbol x)\log\big(1-D(\boldsymbol x)\big)$$

where $\theta$ denotes the parameters of the discriminator network D.
Let us consider the maximum of a more general form of $f_\theta$:

$$f(x)=A\log x+B\log(1-x)$$
To find the maximum of $f(x)$, consider its derivative:

$$\begin{aligned}\frac{\text{d}f(x)}{\text{d}x}&=A\frac{1}{\ln 10}\frac{1}{x}-B\frac{1}{\ln 10}\frac{1}{1-x}\\ &=\frac{1}{\ln 10}\Big(\frac{A}{x}-\frac{B}{1-x}\Big)\\ &=\frac{1}{\ln 10}\,\frac{A-(A+B)x}{x(1-x)}\end{aligned}$$

(The constant factor $\frac{1}{\ln 10}$ comes from differentiating a base-10 logarithm; it has no effect on the location of the extremum.)
Setting $\frac{\text{d}f(x)}{\text{d}x}=0$, we obtain the extremum of $f(x)$:

$$x=\frac{A}{A+B}$$
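A quick symbolic check of this critical point (a sketch using natural logarithms; as noted above, the constant prefactor from the base of the logarithm does not change the result):

```python
import sympy as sp

x, A, B = sp.symbols('x A B', positive=True)
f = A * sp.log(x) + B * sp.log(1 - x)

# The unique critical point is x = A / (A + B) ...
print(sp.solve(sp.Eq(sp.diff(f, x), 0), x))                  # [A/(A + B)]

# ... and the second derivative there is negative, so it is a maximum.
print(sp.simplify(sp.diff(f, x, 2).subs(x, A / (A + B))))    # -(A + B)**3/(A*B)
```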
Therefore, the extremum of $f_\theta$ is likewise attained at:

$$D_\theta=\frac{p_r(\boldsymbol x)}{p_r(\boldsymbol x)+p_g(\boldsymbol x)}$$
That is, when the discriminator network $D_\theta$ is in the state $D_{\theta^*}$, $f_\theta$ attains its maximum, and $\mathcal L(G,D)$ attains its maximum as well.
Returning to the problem of maximizing $\mathcal L(G,D)$: the maximum of $\mathcal L(G,D)$ is attained at

$$D^*=\frac{A}{A+B}=\frac{p_r(\boldsymbol x)}{p_r(\boldsymbol x)+p_g(\boldsymbol x)}$$

which is exactly the optimal state $D^*$ of $D_\theta$.
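The same conclusion can be checked numerically, without calculus: fixing the density values at one point $\boldsymbol x$ (the numbers below are arbitrary), a grid search over the discriminator output recovers $p_r/(p_r+p_g)$ as the maximizer of the pointwise integrand:

```python
import numpy as np

p_r, p_g = 0.4, 0.1                        # arbitrary density values at a fixed x
d = np.linspace(1e-6, 1.0 - 1e-6, 1_000_001)

# Pointwise integrand of L(G, D): p_r*log(d) + p_g*log(1 - d)
integrand = p_r * np.log(d) + p_g * np.log(1.0 - d)

print(f"argmax over grid:  {d[np.argmax(integrand)]:.4f}")   # ~0.8000
print(f"p_r/(p_r + p_g):   {p_r / (p_r + p_g):.4f}")         # 0.8000
```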
2. Generator State
Before working through the second question, we first introduce another distance measure between distributions, closely related to the KL divergence: the JS divergence, defined as a combination of KL divergences:
$$D_{KL}(p||q)=\int_x p(\boldsymbol x)\log\frac{p(\boldsymbol x)}{q(\boldsymbol x)}\text{d}\boldsymbol x$$
$$D_{JS}(p||q)=\frac{1}{2}D_{KL}\Big(p||\frac{p+q}{2}\Big)+\frac{1}{2}D_{KL}\Big(q||\frac{p+q}{2}\Big)$$
The JS divergence overcomes the asymmetry of the KL divergence.
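Here is a small numeric sketch of these definitions for discrete distributions (the two distributions below are arbitrary examples), which also makes the symmetry of JS, and the asymmetry of KL, visible:

```python
import numpy as np

def kl(p, q):
    """D_KL(p || q) for discrete distributions with full support."""
    return np.sum(p * np.log(p / q))

def js(p, q):
    """D_JS(p || q) as the symmetric KL combination defined above."""
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = np.array([0.7, 0.2, 0.1])
q = np.array([0.3, 0.4, 0.3])
print(f"KL(p||q) = {kl(p, q):.4f}, KL(q||p) = {kl(q, p):.4f}")  # not equal
print(f"JS(p||q) = {js(p, q):.4f}, JS(q||p) = {js(q, p):.4f}")  # equal
```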
When D reaches its optimal state $D^*$, consider the JS divergence between $p_r$ and $p_g$ at that point:

$$D_{JS}(p_r||p_g)=\frac{1}{2}D_{KL}\Big(p_r||\frac{p_r+p_g}{2}\Big)+\frac{1}{2}D_{KL}\Big(p_g||\frac{p_r+p_g}{2}\Big)$$
Expanding according to the definition of the KL divergence:

$$\begin{aligned}D_{JS}(p_r||p_g)=&\frac{1}{2}\Big(\log 2+\int_x p_r(\boldsymbol x)\log\frac{p_r(\boldsymbol x)}{p_r(\boldsymbol x)+p_g(\boldsymbol x)}\text{d}\boldsymbol x\Big)\\ &+\frac{1}{2}\Big(\log 2+\int_x p_g(\boldsymbol x)\log\frac{p_g(\boldsymbol x)}{p_r(\boldsymbol x)+p_g(\boldsymbol x)}\text{d}\boldsymbol x\Big)\end{aligned}$$

(Each $\log 2$ term arises because $\frac{p}{(p_r+p_g)/2}=2\cdot\frac{p}{p_r+p_g}$ and each density integrates to 1.)
Combining the constant terms gives:

$$\begin{aligned}D_{JS}(p_r||p_g)=&\frac{1}{2}(\log 2+\log 2)\\ &+\frac{1}{2}\Big(\int_x p_r(\boldsymbol x)\log\frac{p_r(\boldsymbol x)}{p_r(\boldsymbol x)+p_g(\boldsymbol x)}\text{d}\boldsymbol x+\int_x p_g(\boldsymbol x)\log\frac{p_g(\boldsymbol x)}{p_r(\boldsymbol x)+p_g(\boldsymbol x)}\text{d}\boldsymbol x\Big)\end{aligned}$$
That is:

$$\begin{aligned}D_{JS}(p_r||p_g)=&\frac{1}{2}\log 4\\ &+\frac{1}{2}\Big(\int_x p_r(\boldsymbol x)\log\frac{p_r(\boldsymbol x)}{p_r(\boldsymbol x)+p_g(\boldsymbol x)}\text{d}\boldsymbol x+\int_x p_g(\boldsymbol x)\log\frac{p_g(\boldsymbol x)}{p_r(\boldsymbol x)+p_g(\boldsymbol x)}\text{d}\boldsymbol x\Big)\end{aligned}$$
Now consider the loss function when the discriminator network reaches $D^*$:

$$\begin{aligned}\mathcal L(G,D^*)&=\int_x p_r(\boldsymbol x)\log\big(D^*(\boldsymbol x)\big)+p_g(\boldsymbol x)\log\big(1-D^*(\boldsymbol x)\big)\text{d}\boldsymbol x\\ &=\int_x p_r(\boldsymbol x)\log\frac{p_r(\boldsymbol x)}{p_r(\boldsymbol x)+p_g(\boldsymbol x)}\text{d}\boldsymbol x+\int_x p_g(\boldsymbol x)\log\frac{p_g(\boldsymbol x)}{p_r(\boldsymbol x)+p_g(\boldsymbol x)}\text{d}\boldsymbol x\end{aligned}$$
Therefore, once the discriminator network reaches $D^*$, $D_{JS}(p_r||p_g)$ and $\mathcal L(G,D^*)$ satisfy the relation:

$$D_{JS}(p_r||p_g)=\frac{1}{2}\big(\log 4+\mathcal L(G,D^*)\big)$$
That is:

$$\mathcal L(G,D^*)=2D_{JS}(p_r||p_g)-2\log 2$$
For the generator network G, the training objective is $\underset{G}{\text{min}}\,\mathcal L(G,D)$. The JS divergence has the property:

$$D_{JS}(p_r||p_g)\geq 0$$
It follows that $\mathcal L(G,D^*)$ attains its minimum only when $D_{JS}(p_r||p_g)=0$ (i.e., when $p_g=p_r$), where the minimum value is:

$$\mathcal L(G^*,D^*)=-2\log 2$$
The state of the generator network $G^*$ at this point is:

$$p_g=p_r$$
That is, the distribution $p_g$ learned by $G^*$ coincides with the real distribution $p_r$, and the network reaches its equilibrium point, where:

$$D^*=\frac{p_r(\boldsymbol x)}{p_r(\boldsymbol x)+p_g(\boldsymbol x)}=0.5$$
3. Nash Equilibrium
From the derivations above, we can conclude that the generator network G ultimately converges to the real distribution, i.e.:

$$p_g=p_r$$
At this point the generated samples and the real samples come from the same distribution: real and fake are indistinguishable, and the discriminator judges either as real or fake with equal probability, i.e.:

$$D(\cdot)=0.5$$
The loss function at this point is

$$\mathcal L(G^*,D^*)=-2\log 2$$
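Plugging the equilibrium $p_g=p_r$ into the same discrete sketch used above confirms both equilibrium conditions at once (the distribution below is an arbitrary example):

```python
import numpy as np

p_r = np.array([0.5, 0.3, 0.2])
p_g = p_r.copy()                  # equilibrium: generator matches the data

d_star = p_r / (p_r + p_g)
loss = np.sum(p_r * np.log(d_star) + p_g * np.log(1.0 - d_star))

print(d_star)                                       # [0.5 0.5 0.5]
print(f"{loss:.6f} vs {-2.0 * np.log(2.0):.6f}")    # both -1.386294
```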