最近在学习GAN相关的知识,现在对GAN有了比较清晰的了解,希望在这里给大家分享一下我的理解。并且本文避开了许多较难的数学部分,旨在给初学者一些直观的了解。
GAN
首先是原始的GAN,我们直接看GAN优化的函数:
min
G
max
D
V
(
D
,
G
)
=
E
x
∼
p
data
(
x
)
[
log
D
(
x
)
]
+
E
z
∼
p
z
(
z
)
[
log
(
1
−
D
(
G
(
z
)
)
)
]
\min _{G} \max _{D} V(D, G)=E_{x \sim p_{\text {data }}(x)}[\log D(x)]+E_{z \sim p_{z}(z)}[\log (1-D(G(z)))]
GminDmaxV(D,G)=Ex∼pdata (x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
其中,G表示生成器,D表示判别器,data,x可以理解成真实样本的样本空间和真实样本,z表示输入的噪声。
接下来这句话是理解这个优化函数关键:这里假设真实样本标签为1,虚假的生成样本标签为0.
好了,接下来,我们分开看生成器G和判别器G。
生成器G
生成器G是为了使该优化函数最小,而我们在训练生成器G时,控制判别器D权重不变,那我们只需要看函数的第二部分,相当于优化:
min
G
E
z
∼
p
z
(
z
)
[
log
(
1
−
D
(
G
(
z
)
)
)
\min _{G} E_{z \sim p_{z}(z)}[\log (1-D(G(z)))
GminEz∼pz(z)[log(1−D(G(z)))
这就很好理解了,想让这个函数最小,那么生成器就希望 D(G(z))=1,意思就是生成器G希望判别器D把生成的图片G(z)判别为1,也就是判别为真样本的意思。这样就成功糊弄了判别器D。
判别器D
反过来,我们在训练判别器D时,控制生成器G的权重不变,优化函数为:
max
D
(
E
x
∼
p
data
(
x
)
[
log
D
(
x
)
]
+
E
z
∼
p
z
(
z
)
[
log
(
1
−
D
(
G
(
z
)
)
)
]
)
\max _{D} (E_{x \sim p_{\text {data }}(x)}[\log D(x)]+E_{z \sim p_{z}(z)}[\log (1-D(G(z)))])
Dmax(Ex∼pdata (x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))])
log函数的可行域是大于0的,所以我们必须保证判别器D的输出在0和1之间。那我们希望优化函数最大就是希望D(x)=1,且D(G(z))=0。意思就是希望判别器判断真实样本x为真样本,生成出来的虚假样本G(z)为假样本。
了解了原始的GAN,让我们一起来看GAN的进阶版本WGAN。
WGAN(WGAN-GP)
网上很多关于WGAN的教程都是,推土机距离,Kantorovich-Rubinstein duality理论,Lipschitz约束等等概念,本文完全避开这些概念,给读者最为直观的理解。
就像上文提到的那样,判别器D的输出范围是0-1之间,这样就会带来一些弊端。比如生成器生成的图像较差时,第二次生成的图像明明比第一次进步了,却都被判别器打了接近0分。这明显是不公平的,这样下去生成器就不知道怎么进步了。于是乎WGAN就直接把优化函数的log去掉了,优化函数变为:
min
G
max
D
V
(
D
,
G
)
=
E
x
∼
p
data
(
x
)
[
D
(
x
)
]
+
E
z
∼
p
z
(
z
)
[
1
−
D
(
G
(
z
)
)
]
\min _{G} \max _{D} V(D, G)=E_{x \sim p_{\text {data }}(x)}[D(x)]+E_{z \sim p_{z}(z)}[1-D(G(z))]
GminDmaxV(D,G)=Ex∼pdata (x)[D(x)]+Ez∼pz(z)[1−D(G(z))]
再进一步简洁一些就是:
min
G
max
D
V
(
D
,
G
)
=
E
x
∼
p
data
(
x
)
[
D
(
x
)
]
−
E
z
∼
p
z
(
z
)
[
D
(
G
(
z
)
)
]
\min _{G} \max _{D} V(D, G)=E_{x \sim p_{\text {data }}(x)}[D(x)]-E_{z \sim p_{z}(z)}[D(G(z))]
GminDmaxV(D,G)=Ex∼pdata (x)[D(x)]−Ez∼pz(z)[D(G(z))]
现在好处是判别器可以在实数范围内任意打分了,但是,问题又出现了,没有了0-1的输出约束,忽大忽小的打分,会使得梯度非常大,从而使判别器训练振荡而不收敛。这时候就引入了权重修剪(weight clipping)策略,就是把判别器的权重限制在一个区间里,比如[-0.01,0.01],以此来确保判别器的权重不会发生大的振荡变化。
当然除了权重修剪(weight clipping)策略以外,还有梯度约束(gradient penalty)策略。就是在损失函数后面加了个惩罚项:
min
G
max
D
V
(
D
,
G
)
=
E
x
∼
p
data
(
x
)
[
D
(
x
)
]
−
E
z
∼
p
z
(
z
)
[
D
(
G
(
z
)
)
]
+
λ
(
∥
∇
x
^
D
(
x
^
)
∥
2
−
1
)
2
\min _{G} \max _{D} V(D, G)=E_{x \sim p_{\text {data }}(x)}[D(x)]-E_{z \sim p_{z}(z)}[D(G(z))]+\lambda\left(\left\|\nabla_{\hat{x}} D(\hat{x})\right\|_{2}-1\right)^{2}
GminDmaxV(D,G)=Ex∼pdata (x)[D(x)]−Ez∼pz(z)[D(G(z))]+λ(∥∇x^D(x^)∥2−1)2
其中
x
^
←
ϵ
x
+
(
1
−
ϵ
)
G
(
z
)
\hat{x} \leftarrow \epsilon x+(1-\epsilon) G(z)
x^←ϵx+(1−ϵ)G(z)
加上这个惩罚项的意思就是希望判别器在样本空间
x
^
\hat{x}
x^内(
x
^
\hat{x}
x^包括真实样本x,生成的假样本G(z)以及两者之间空间),梯度的L2范数趋近于1。换句话说就是为了使梯度保持在一个合理较小的值,避免判别器在训练时出现大的振荡而不收敛的情况。