Contents
Introduction
- WGAN 增加了 GAN 模型训练的稳定性,但有时仍然会有生成质量不高或难以收敛的问题。作者发现上述问题经常是由 WGAN 中的 weight clipping 导致的,因此作者提出了 WGAN-GP,用 gradient penalty 代替 weight clipping 来给 discriminator 施加 Lipschitz 约束
- WGAN-GP 的生成性能和训练稳定性都超过了 WGAN
Difficulties with weight constraints
Capacity underuse
Exploding and vanishing gradients
Gradient penalty
- A differentiable function is 1-Lipschitz if and only if it has gradients with norm less than or equal to 1 everywhere.
D ∈ 1 − Lipschitz ⟺ ∥ ∇ x D ( x ) ∥ ≤ 1 for all x D \in 1-\text { Lipschitz } \Longleftrightarrow\left\|\nabla_x D(x)\right\| \leq 1 \text { for all } \mathrm{x} D∈1− Lipschitz ⟺∥∇xD(x)∥≤1 for all x此外,作者还给出了 Properties of the optimal WGAN critic
由引理 1 可知,最优的 discriminator f f f 满足 gradient norm 几乎处处为 1. - 鉴于上述性质,作者直接对 discriminator 的输出相对于输入的梯度范数进行了约束,损失函数为
L = E x ~ ∼ P g [ D ( x ~ ) ] − E x ∼ P r [ D ( x ) ] ⏟ Original critic loss + λ E x ^ ∼ P x ~ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] . ⏟ Our gradient penalty L=\underbrace{\underset{\tilde{\boldsymbol{x}} \sim \mathbb{P}_g}{\mathbb{E}}[D(\tilde{\boldsymbol{x}})]-\underset{\boldsymbol{x} \sim \mathbb{P}_r}{\mathbb{E}}[D(\boldsymbol{x})]}_{\text {Original critic loss }}+\underbrace{\lambda \underset{\hat{\boldsymbol{x}} \sim \mathbb{P}_{\boldsymbol{\tilde x}}}{\mathbb{E}}\left[\left(\left\|\nabla_{\hat{\boldsymbol{x}}} D(\hat{\boldsymbol{x}})\right\|_2-1\right)^2\right] .}_{\text {Our gradient penalty }} L=Original critic loss x~∼PgE[D(x~)]−x∼PrE[D(x)]+Our gradient penalty λx^∼Px~E[(∥∇x^D(x^)∥2−1)2].其中 P g \mathbb P_g Pg 为生成的数据分布, P r \mathbb P_r Pr 为真实数据分布,后一项为在 WGAN 基础上新添加的梯度惩罚项,使得 discriminator 倾向于满足 ∣ ∣ ∇ x ^ D ( x ^ ) ∣ ∣ = 1 ||\nabla_{\hat x}D(\hat x)||=1 ∣∣∇x^D(x^)∣∣=1 for all x ^ \hat x x^. 作者选取 λ = 10 \lambda=10 λ=10. 另外注意到作者施加的梯度惩罚是 Two-sided penalty,即只是要求梯度范数接近 1 而不要求其一定是小于 1,在实际实验中,作者发现 two-sided penalty 比 owo-sided penalty 表现更好 - Sampling distribution.
D
(
x
^
)
D(\hat x)
D(x^) 的定义域是整个图像空间,而在整个图像空间上施加 Lipschitz constraint 是不现实的,因此作者只是在
P
x
^
\mathbb P_{\boldsymbol{\hat x}}
Px^ 上采样,
P
x
^
\mathbb P_{\boldsymbol{\hat x}}
Px^ 被定义为 sampling uniformly along straight lines between pairs of points sampled from the data distribution
P
r
\mathbb P_r
Pr and the generator distribution
P
g
\mathbb P_g
Pg,即从
P
r
\mathbb P_r
Pr 和
P
g
\mathbb P_g
Pg 中各采样出一个点,把这两个点相连,在这两个点中间做一个 random sample. (作者的 motivation 来自 Proposition 1,the optimal critic contains straight lines with gradient norm 1 connecting coupled points from
P
r
\mathbb P_r
Pr and
P
g
\mathbb P_g
Pg.) 只采样
P
r
\mathbb P_r
Pr 和
P
g
\mathbb P_g
Pg 之间的点在实验中有较好的性能,并且直观上看,由于我们在训练时是想把
P
g
\mathbb P_g
Pg 往
P
r
\mathbb P_r
Pr 方向移动,因此采样它们中间的点也比较合理,可以使得 generator 在将
P
g
\mathbb P_g
Pg 往
P
r
\mathbb P_r
Pr 方向移动时,discriminator 提供更有意义的梯度信息
- No critic batch normalization. 大多数 GAN 在 generator 和 discriminator 里都用了 BN 来帮助稳定训练,但作者认为 BN 将 discriminator 的任务从 mapping a single input to a single output 变为了 mapping from an entire batch of inputs to a batch of outputs,而作者提出的梯度惩罚是针对单个输入样本而非一整个 batch 的,因此惩罚项就没那么有效了。为此,作者没有在 discriminator 里使用 BN,并且建议使用 LN 代替 BN
Meaningful loss curves and detecting overfitting