深度学习之生成对抗网络(5)纳什均衡

深度学习之生成对抗网络(5)纳什均衡


 现在我们从理论层面进行分析,通过博弈学习的训练方式,生成器G和判别器D分别会达到什么平衡状态。具体地,我们将探索以下两个问题:

  • 固定G,D会收敛到什么最优状态 D ∗ D^* D
  • 在D达到最优状态 D ∗ D^* D后,G会收敛到什么状态?

首先我们通过 x r ∼ p r ( ⋅ ) \boldsymbol x_r\sim p_r (\cdot) xrpr()一维正态分布的例子给出一个直观的解释。如下图所示,黑色虚线曲线代表了真实数据的分布 p r ( ⋅ ) p_r (\cdot) pr(),为某正态分布 N ( μ , σ 2 ) \mathcal N(μ,σ^2) N(μ,σ2),绿色实线代表了生成网络学习到的分布 x f ∼ p g ( ⋅ ) \boldsymbol x_f\sim p_g (\cdot) xfpg(),蓝色虚线代表了判别器的决策边界曲线,图(a)、(b)、(c)、(d)分别代表了生成网络的学习轨迹。在初始状态,如下图(a)所示, p g ( ⋅ ) p_g (\cdot) pg()分布与 p r ( ⋅ ) p_r (\cdot) pr()差异较大,判别器可以很轻松地学习到明确的决策边界,即图(a)中的蓝色虚线,将来自 p g ( ⋅ ) p_g (\cdot) pg()的采样点判定为0, p r ( ⋅ ) p_r (\cdot) pr()中的采样点判定为1.随着生成网络的分布 p g ( ⋅ ) p_g (\cdot) pg()越来越逼近真是分布 p r ( ⋅ ) p_r (\cdot) pr(),判别器越来越困难将真假样本区分开,如下图(b)(c)所示。最后,生成网络学习到的分布 p g ( ⋅ ) = p r ( ⋅ ) \boldsymbol {p_g (\cdot)=p_r (\cdot)} pg()=pr()时,此时从生成网络中采样的样本非常逼真,判别器无法区分,即判定为真假样本的概率均等,如下图(d)所示。

 这个例子直观地解释了GAN网络的训练过程。

纳什均衡点

1. 判别器状态

 现在来推导第一个问题。回顾GAN的损失函数:
L ( G , D ) = ∫ x p r ( x ) log ⁡ ( D ( x ) ) + ∫ z p z ( z ) log ⁡ ( 1 − D ( g ( z ) ) ) d z = ∫ x p r ( x ) log ⁡ ( D ( x ) ) + p g ( x ) l o g ⁡ ( 1 − D ( x ) ) d x \begin{aligned}\mathcal L(G,D)&=\int_x {p_r (\boldsymbol x)} \text{log}⁡(D(\boldsymbol x))+\int_z {p_z (\boldsymbol z)} \text{log}⁡(1-D(g(\boldsymbol z)))d\boldsymbol z\\ &=\int_x {p_r (\boldsymbol x) } \text{log}⁡(D(\boldsymbol x))+p_g (\boldsymbol x)log⁡(1-D(\boldsymbol x))d\boldsymbol x\end{aligned} L(G,D)=xpr(x)log(D(x))+zpz(z)log(1D(g(z)))dz=xpr(x)log(D(x))+pg(x)log(1D(x))dx
对于判别器D,优化的目标是最大化 L ( G , D ) \mathcal L(G,D) L(G,D)函数,需要找出函数:
f θ = p r ( x ) log⁡ ( D ( x ) ) + p g ( x ) log ⁡ ( 1 − D ( x ) ) f_θ=p_r (\boldsymbol x) \text{log⁡}(D(\boldsymbol x))+p_g (\boldsymbol x)\text{log}⁡(1-D(\boldsymbol x)) fθ=pr(x)log⁡(D(x))+pg(x)log(1D(x))
的最大值,其中 θ θ θ为判别器D的网络参数。

 我们来考虑 f θ f_θ fθ更通用的函数的最大值情况:
f ( x ) = A log x + B log ⁡ ( 1 − x ) f(x)=A \text{log}x+B\text{log}⁡(1-x) f(x)=Alogx+Blog(1x)
要求得 f ( x ) f(x) f(x)的最大值。考虑 f ( x ) f(x) f(x)的导数:
d f ( x ) d x = A 1 ln ⁡ 10 1 x − B 1 ln ⁡ ⁡ 10 1 1 − x = 1 ln ⁡ 10 ( A x − B 1 − x ) = 1 ln ⁡ 10 A − ( A + B ) x x ( 1 − x ) \begin{aligned}\frac{\text{d}f(x)}{\text{d}x} &=A \frac{1}{\text{ln}⁡10} \frac{1}{x}-B \frac{1}{\text{ln}⁡⁡10} \frac{1}{1-x}\\ &=\frac{1}{\text{ln}⁡10} (\frac{A}{x}-\frac{B}{1-x})\\ &=\frac{1}{\text{ln}⁡10} \frac{A-(A+B)x}{x(1-x)}\end{aligned} dxdf(x)=Aln101x1Bln1011x1=ln101(xA1xB)=ln101x(1x)A(A+B)x
d f ( x ) d x = 0 \frac{\text{d}f(x)}{\text{d}x}=0 dxdf(x)=0,我们可以求得 f ( x ) f(x) f(x)函数的极值点:
x = A A + B x=\frac{A}{A+B} x=A+BA
因此,可以得知, f θ f_θ fθ函数的极值点同样为:
D θ = p r ( x ) p r ( x ) + p g ( x ) D_θ=\frac{p_r (\boldsymbol x)}{p_r (\boldsymbol x)+p_g (\boldsymbol x)} Dθ=pr(x)+pg(x)pr(x)
也就是说,判别器网络 D θ D_θ Dθ处于 D θ ∗ D_{θ^*} Dθ状态时, f θ f_θ fθ函数取得最大值, L ( G , D ) \mathcal L(G,D) L(G,D)函数也取得最大值。

 现在回到最大化 L ( G , D ) \mathcal L(G,D) L(G,D)的问题, L ( G , D ) \mathcal L(G,D) L(G,D)的最大值点在:
D ∗ = A A + B = p r ( x ) p r ( x ) + p g ( x ) D^*=\frac{A}{A+B}=\frac{p_r (\boldsymbol x)}{p_r (\boldsymbol x)+p_g (\boldsymbol x)} D=A+BA=pr(x)+pg(x)pr(x)
时取得,此时也是 D θ D_θ Dθ的最优状态 D ∗ D^* D


2. 生成器状态

 再推导第二个问题之前,我们先介绍以下与KL散度类似的另一个分布距离度量标准:JS散度,它定义为KL散度的组合:
D K L ( p ∣ ∣ q ) = ∫ x p ( x ) log⁡ p ( x ) q ( x ) d x D_{KL} (p||q)=∫_x p(\boldsymbol x) \text{log⁡}\frac{p(\boldsymbol x)}{q(\boldsymbol x)} \text{d}\boldsymbol x DKL(pq)=xp(x)log⁡q(x)p(x)dx
D J S ( p ∣ ∣ q ) = 1 2 D K L ( p ∣ ∣ p + q 2 ) + 1 2 D K L ( q ∣ ∣ p + q 2 ) D_{JS} (p||q)=\frac{1}{2} D_{KL} \Big(p||\frac{p+q}{2}\Big)+\frac{1}{2} D_{KL} \Big(q||\frac{p+q}{2}\Big) DJS(pq)=21DKL(p2p+q)+21DKL(q2p+q)
JS散度克服了KL散度不对称的缺陷。

 当D达到最优状态 D ∗ D^* D时,我们来考虑此时 p r p_r pr p g p_g pg的JS散度:
D J S ( p r ∣ ∣ p g ) = 1 2 D K L ( p r ∣ ∣ p r + p g 2 ) + 1 2 D K L ( p g ∣ ∣ p r + p g 2 ) D_{JS} (p_r ||p_g)=\frac{1}{2} D_{KL} \Big(p_r ||\frac{p_r+p_g}{2}\Big)+\frac{1}{2} D_{KL} \Big(p_g ||\frac{p_r+p_g}{2}\Big) DJS(prpg)=21DKL(pr2pr+pg)+21DKL(pg2pr+pg)
根据KL散度的定义展开为:
D J S ( p r ∣ ∣ p g ) = 1 2 ( log ⁡ 2 + ∫ x p r ( x ) log ⁡ p r ( x ) p r ( x ) + p g ( x ) d x ) + 1 2 ( log ⁡ 2 + ∫ x p r ( x ) log p g ( x ) p r ( x ) + p g ( x ) d x ) \begin{aligned}D_{JS} (p_r ||p_g)=\frac{1}{2} \Big(\text{log}⁡2+∫_x p_r (\boldsymbol x) \text{log}\frac{⁡p_r (\boldsymbol x)}{p_r (\boldsymbol x)+p_g (\boldsymbol x)} \text{d}\boldsymbol x\Big)\\ +\frac{1}{2}\Big(\text{log}⁡2+∫_x p_r (\boldsymbol x) \text{log}\frac{p_g (\boldsymbol x)}{p_r (\boldsymbol x)+p_g (\boldsymbol x)} \text{d}\boldsymbol x\Big)\end{aligned} DJS(prpg)=21(log2+xpr(x)logpr(x)+pg(x)pr(x)dx)+21(log2+xpr(x)logpr(x)+pg(x)pg(x)dx)
合并常数项可得:
D J S ( p r ∣ ∣ p g ) = 1 2 ( log ⁡ 2 + log ⁡ 2 ) + 1 2 ( ∫ x p r ( x ) log p r ( x ) p r ( x ) + p g ( x ) d x + ∫ x p r ( x ) log p g ( x ) p r ( x ) + p g ( x ) d x ) \begin{aligned}&D_{JS} (p_r ||p_g)=\frac{1}{2}(\text{log}⁡2+\text{log}⁡2)\\ &+\frac{1}{2} \Big(∫_x p_r (\boldsymbol x) \text{log}\frac{p_r (\boldsymbol x)}{p_r (\boldsymbol x)+p_g (\boldsymbol x)} \text{d}\boldsymbol x+∫_x p_r (\boldsymbol x) \text{log}\frac{p_g (\boldsymbol x)}{p_r (\boldsymbol x)+p_g (\boldsymbol x)} \text{d}\boldsymbol x\Big)\end{aligned} DJS(prpg)=21(log2+log2)+21(xpr(x)logpr(x)+pg(x)pr(x)dx+xpr(x)logpr(x)+pg(x)pg(x)dx)
即:
D J S ( p r ∣ ∣ p g ) = 1 2 ( log ⁡ ⁡ 4 ) + 1 2 ( ∫ x p r ( x ) log p r ( x ) p r ( x ) + p g ( x ) d x + ∫ x p r ( x ) log p g ( x ) p r ( x ) + p g ( x ) d x ) \begin{aligned}&D_{JS} (p_r ||p_g)=\frac{1}{2}(\text{log}⁡⁡4)\\ &+\frac{1}{2} \Big(∫_x p_r (\boldsymbol x) \text{log}\frac{p_r (\boldsymbol x)}{p_r (\boldsymbol x)+p_g (\boldsymbol x)} \text{d}\boldsymbol x+∫_x p_r (\boldsymbol x) \text{log}\frac{p_g (\boldsymbol x)}{p_r (\boldsymbol x)+p_g (\boldsymbol x)} \text{d}\boldsymbol x\Big)\end{aligned} DJS(prpg)=21(log4)+21(xpr(x)logpr(x)+pg(x)pr(x)dx+xpr(x)logpr(x)+pg(x)pg(x)dx)
考虑在判别网络到达 D ∗ D^* D时,此时的损失函数为:
L ( G , D ∗ ) = ∫ x p r ( x ) log ( D ∗ ( x ) ) + p g ( x ) log ⁡ ( 1 − D ∗ ( x ) ) d x = ( ∫ x p r ( x ) log p r ( x ) p r ( x ) + p g ( x ) d x + ∫ x p r ( x ) log p g ( x ) p r ( x ) + p g ( x ) d x ) \begin{aligned}\mathcal L(G,D^* )&=∫_x p_r (\boldsymbol x) \text{log}\big(D^* (\boldsymbol x)\big)+p_g (\boldsymbol x)\text{log}⁡\big(1-D^* (\boldsymbol x)\big)\text{d}\boldsymbol x\\ &=\Big(∫_x p_r (\boldsymbol x) \text{log}\frac{p_r (\boldsymbol x)}{p_r (\boldsymbol x)+p_g (\boldsymbol x)} \text{d}\boldsymbol x+∫_x p_r (\boldsymbol x) \text{log}\frac{p_g (\boldsymbol x)}{p_r (\boldsymbol x)+p_g (\boldsymbol x)} \text{d}\boldsymbol x\Big)\end{aligned} L(G,D)=xpr(x)log(D(x))+pg(x)log(1D(x))dx=(xpr(x)logpr(x)+pg(x)pr(x)dx+xpr(x)logpr(x)+pg(x)pg(x)dx)
因此在判别网络到达 D ∗ D^* D时, D J S ( p r ∣ ∣ p g ) D_{JS} (p_r ||p_g) DJS(prpg) L ( G , D ∗ ) \mathcal L(G,D^* ) L(G,D)满足关系:
D J S ( p r ∣ ∣ p g ) = 1 2 ( log ⁡ 4 + L ( G , D ∗ ) ) D_{JS} (p_r ||p_g)=\frac{1}{2}\big(\text{log}⁡4+\mathcal L(G,D^* )\big) DJS(prpg)=21(log4+L(G,D))
即:
L ( G , D ∗ ) = 2 D J S ( p r ∣ ∣ p g ) − 2 log ⁡ 2 \mathcal L(G,D^* )=2D_{JS} (p_r ||p_g)-2 \text{log}⁡2 L(G,D)=2DJS(prpg)2log2
 对于生成网络G而言,训练目标是 min G L ( G , D ) \underset{G}{\text{min}}\mathcal L(G,D) GminL(G,D),考虑到JS散度具有性质:
D J S ( p r ∣ ∣ p g ) ≥ 0 D_{JS} (p_r ||p_g)≥0 DJS(prpg)0
因此 L ( G , D ∗ ) \mathcal L(G,D^* ) L(G,D)取得最小值仅在 D J S ( p r ∣ ∣ p g ) = 0 D_{JS} (p_r ||p_g)=0 DJS(prpg)=0时(此时 p g = p r p_g=p_r pg=pr), L ( G , D ∗ ) \mathcal L(G,D^* ) L(G,D)取得最小值:
L ( G ∗ , D ∗ ) = − 2 log ⁡ 2 L(G^*,D^* )=-2\text{log}⁡2 L(G,D)=2log2
此时生成网络 G ∗ G^* G的状态是:
p g = p r p_g=p_r pg=pr
G ∗ G^* G的学到的分布 p g p_g pg与真是分布 p r p_r pr一致,网络达到平衡点,此时:
D ∗ = p r ( x ) p r ( x ) + p g ( x ) = 0.5 D^*=\frac{p_r (\boldsymbol x)}{p_r (\boldsymbol x)+p_g (\boldsymbol x)}=0.5 D=pr(x)+pg(x)pr(x)=0.5


3. 纳什均衡点

 通过上面的推导,我们可以总结出生成网络G最终将收敛到真是分布,即:
p g = p r p_g=p_r pg=pr
此时生成的样本与真实样本来自统一分部,真假难辨,在判别器中均由相同的概率判定为真或假,即
D ( ⋅ ) = 0.5 D(\cdot)=0.5 D()=0.5
此时损失函数为
L ( G ∗ , D ∗ ) = − 2 log⁡ 2 \mathcal L(G^*,D^* )=-2\text{log⁡}2 L(G,D)=2log⁡2

  • 6
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值