在学习 Variational Auto-Encoder 时,同时注意到了 GAN 研究的火热。但当时觉得 GAN 非常不成熟(训练不稳定,依赖各种说不清的 tricks;没有有效的监控指标,需要大量的人工判断,因此难以扩展到图像之外的高维数据)。在读了 Goodfellow 的 tutorial 后[2],开始黑转路人,觉得 GAN 虽然缺点不少,但优点也很明显。WGAN[5, 6] 等工作出现后,开始逐渐路人转粉,对 GAN 产生了兴趣。
这里,我们仅仅从直观上讨论GAN框架及相关变种,将理论留待将来讨论。
1. Basic GAN
本质上,GAN 是一种训练模式,而非一种待定的网络结构[1]。
图1. GAN基本框架【src】
GAN 的基本思想是,生成器和判别器玩一场“道高一尺,魔高一丈”的游戏:判别器要练就“火眼金睛”,尽量区分出真实的样本(如真实的图片)和由生成器生成的假样本;生成器要学着“以假乱真”,生成出使判别器判别为真实的“假样本”。
竞争的理想怦是双方都不断进步——(理想情况下)判别器的眼睛越发“雪亮”,生成器的欺骗能力也不断提高。对抗的胜负无关紧要,重要的是,最后生成器的欺骗能力足够好,能够生成与真实样本足够相似的样本——直观而言,生成的样本看起来像是训练集(如图片)的样本;形式化的,生成器生成样本的分布,应该与训练集样本分布接近。
理论上可以,在理想条件下,生成器是可以通过这种对抗得到目标分布的(即生成足够真实的样本)。
假设要训练数据为灰度 MNIST(归一化为[0, 1]之间),生成器(generator)可以为任意输入为隐变量维度,输出为 1 x 28 x 28的模型。一个示例模型定义如下:
def build_generator(latent_size):
model = Sequential()
model.add(Dense(1024, input_dim=latent_size, activation='relu'))
model.add(Dense(28 * 28, activation='tanh'))
model.add(Reshape((1, 28, 28)))
return model
判别器(discriminator)可以为任意输入 1 x 28 x 28,输出为1维且在 [0, 1] 之间(经过sigmoid激活)的模型。一个示例模型定义如下:
def build_discriminator():
model = Sequential()
model.add(Flatten(input_shape=(1, 28, 28)))
model.add(Dense(256, activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(1), activation='sigmoid')
return model
输出值表示判别器判别输入样本为真的概率。即输出值越接近1,判别器越确信样本为真;输出值越接近0,判别器越确信样本为假。
判别器
L
D
=
−
Σ
i
log
(
D
(
x
i
)
)
−
Σ
i
log
(
1
−
D
(
G
(
z
i
)
)
)
L_D = -\Sigma_i \log(D(\textbf{x}_i)) -\Sigma_i \log(1-D(G(\textbf{z}_i)))
LD=−Σilog(D(xi))−Σilog(1−D(G(zi)))
判别器的训练的目标为:对于真实样本,输出尽量接近1;对于生成器生成的假样本,输出尽量接近0。
也即训练判别器时,真实样本的标签为1,生成样本的标签为0。
生成器
L
G
=
Σ
i
log
(
1
−
D
(
G
(
z
i
)
)
)
L_G = \Sigma_i \log(1-D(G(\textbf{z}_i)))
LG=Σilog(1−D(G(zi)))
判别器的训练的目标为生成的假样本,使判别器的输出尽量接近1,即尽量以假乱真。
为了解决训练过程中,梯度消失的问题,一般使用如下损失函数 (Trick 2):
L
G
=
−
Σ
i
log
(
D
(
G
(
z
i
)
)
)
L_G = -\Sigma_i \log(D(G(\textbf{z}_i)))
LG=−Σilog(D(G(zi)))
为使用这个损失函数,只需要将生成样本的标签为1,同时使用变通的交叉熵损失函数。
GAN 的训练流程如下[1]:
∇
θ
1
m
Σ
i
=
1
m
−
log
(
D
(
G
(
z
(
i
)
)
)
)
\nabla\theta\frac{1}{m}\Sigma_{i=1}^{m}-\log(D(G(\textbf{z}^{(i)})))
∇θm1Σi=1m−log(D(G(z(i))))
GAN足够简单,也有理论上的保证。但在实践中,需要许多技巧和运气才能正常把“游戏玩下去”。这里,我们不考虑理论,而是关注不要GAN变种在损失函数设计的差异。
2. Least Squares GAN
我们以[4]中 Eq (9) 为例来介绍 LSGAN。其中判别器的定义如下:
def build_discriminator():
model = Sequential()
model.add(Flatten(input_shape=(1, 28, 28)))
model.add(Dense(256, activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(1), activation='linear') ## change 1
return model
与basic GAN 唯一不同在判别器的最后输出不使用 sigmoid 激活,而是使用了线性函数(也即不使用激活)(第6行 change 1)。
有了生成器和判别器的定义,我们来实际构造两者以用于训练:
# 构造判别器
disc = build_discriminator()
disc.compile(optimizer=Adam(lr=lr),loss='mse')
# 构建生成器
generator = build_generator(latent_size)
latent = Input(shape=(latent_size,))
# 生成假图片
fake = generator(latent)
# 我们要训练生成器,因此固定判别的权值不变
disc.trainable = False
fake = disc(fake)
combined = Model(input=latent, output=fake)
combined.compile(optimizer=Adam(lr=lr), loss='mse')
不同于basic GAN, LSGAN的训练损失函数由交叉熵改为MSE(Mean Squared Error)。
for epoch in range(nb_epochs):
for index in range(nb_batches):
## 1) 训练判别器
# 1.1采样隐变量并生成假样本
noise = np.random.uniform(-1, 1, (batch_size, latent_size))
generated_images = generator.predict(noise, verbose=0)
# 1.2 从训练中采样真实样本
image_batch = X_train[index * batch_size:(index + 1) * batch_size]
label_batch = y_train[index * batch_size:(index + 1) * batch_size]
# 利用真假数据进行训练
X = np.concatenate((image_batch, generated_images))
# 设定真假数据的损失,a == 0, b == 1
y = np.array([1] * len(image_batch) + [0] * batch_size)
disc.train_on_batch(X, y)
## 2)训练生成器
# 采样隐变量
noise = np.random.uniform(-1, 1, (batch_size, latent_size))
target = np.ones(batch_size) # 设定生成样本的损失 c == b == 1
combined.train_on_batch(noise, target)
图2是训练过程中,由生成器采样的几张示例图片。完整的示例可以参见repo。
图2. LSGAN随机采样生成的图片(Epoch: 443)
由于仅作为示例以及时间和计算资源的限制,从模型结构到优化器的参数都没有经过任何调优。因此,这里生成的图片的质量不应该做为算法优劣的依据(下同)。
3. Wasserstein GAN(WGAN)
WGAN采用线性的损失函数,为此我们定义:
def dummy_loss(loss_to_backprop, y_pred):
return K.mean(loss_to_backprop * y_pred) # delta == loss_to_backprop
disc.compile(optimizer=Adam(lr=lr),loss=dummy_loss)
combined.compile(optimizer=Adam(lr=lr), loss=dummy_loss)
为应用这个损失函数,代码更改如下(第12和18行,change 2、3)。
for epoch in range(nb_epochs):
for index in range(nb_batches):
## 1) 训练判别器
# 1.1采样隐变量并生成假样本
noise = np.random.uniform(-1, 1, (batch_size, latent_size))
generated_images = generator.predict(noise, verbose=0)
# 1.2 从训练中采样真实样本
image_batch = X_train[index * batch_size:(index + 1) * batch_size]
label_batch = y_train[index * batch_size:(index + 1) * batch_size]
# 利用真假数据进行训练
X = np.concatenate((image_batch, generated_images))
y = np.array([-1] * len(image_batch) + [1] * batch_size) ## change 2
disc.train_on_batch(X, y)
## 2)训练生成器
# 采样隐变量
noise = np.random.uniform(-1, 1, (batch_size, latent_size))
target = -np.ones(batch_size) ## change 3
combined.train_on_batch(noise, target)
WGAN有如下突出优点[6]:
- 训练稳定,不需要平稳生成器和判别器。
- loss 值与生成样本质量相关,可以用来监督训练进程,不需要人工判断干预。
完整的示例可以参见repo。读者可以自行验证,D_loss 及生成图像的质量变化。
4. GLSGAN
[7]提出了 Loss Sensitive GAN,并随后发现,可以和WGAN在统一的框架下研究,即 generalized LSGAN(图3)。
图3. 【src】
GLSGAN 使用 LeakyReLU作用激活,其中
s
∈
(
−
∞
,
1
]
s \in (-\infty, 1]
s∈(−∞,1]。
KaTeX parse error: Unknown column alignment: x at position 37: …\begin{array} x̲x, \ if\ x \ge…
def build_discriminator():
model = Sequential()
model.add(Flatten(input_shape=(1, 28, 28)))
model.add(Dense(256, activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(1), activation='linear')
model.add(LeakyReLU(slope)) ##
return model
下面是不同s下,训练的模型生成的示例图片。
图4. Slope: 1 (WGAN), Epoch: 170
图5. Slope: 0 (Loss Sensitive GAN), Epoch: 189
图6. Slope: -1 (L1 Loss), Epoch: 399
非线性损失
GLSGAN并不限定损失函数为(分段)线性。这里使用 Exponential Linear Unit(ELU)。
KaTeX parse error: Unknown column alignment: x at position 31: …\begin{array} x̲x, \ if\ x \ge…
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-dILAbTgm-1572062021724)(https://ai2-s2-public.s3.amazonaws.com/figures/2016-11-08/0373b97580cdfd0b69f165e1a946bae62da95dce/1-Figure2-1.png)]
图7. Exponential Linear Unit vs. ReLU【src】
def build_discriminator():
# Other Code goes here...
model.add(ELU) # Exponential Linear Unit
return model
图8. ELU, Epoch: 367
一个完整的示例见repo。基于 torch 的official repo。
5. 讨论
损失函数
损失函数唯一重要的地方在于,不断驱动两个网络的竞争。直观上,判别网络将真实样本和生成样本,向坐标轴上的两个不同的区域移动。
- 对于basic GAN,这两个区域分别分别是0(生成)和1(真实),使用的损失函数是对数函数( f ( x ) = − log ( x ) f(x)=-\log(x) f(x)=−log(x))(即交叉熵)[1]。
- 对于 Least Squared GAN,这两个区域分别是 a 和 b(a < b),使用的损失函数是二次函数( f ( x ) = x 2 f(x)= x^2 f(x)=x2)[4]。
- 对于 WGAN,这两个区域分别 + ∞ +\infty +∞(真实样本)和 − ∞ -\infty −∞(生成样本),使用的是线性损失函数( f ( x ) = x f(x)=x f(x)=x)[6]。
- 对于 Loss-Sensitive GAN,这两个区域分别是 + ∞ +\infty +∞(真实样本)和 ( − ∞ , 0 ] (-\infty, 0] (−∞,0][7],使用的是ReLu损失函数。
- 对于 Generalized LSGAN( 0 < γ < 1 0 < \gamma < 1 0<γ<1),这两个区域分别是 + ∞ +\infty +∞(真实样本)和 ( − ∞ , 0 ] (-\infty, 0] (−∞,0][7],使用的是Leaky ReLu损失函数。
- 对于 Generalized LSGAN( γ < 0 \gamma < 0 γ<0),这两个区域分别是 + ∞ +\infty +∞(真实样本)和 0 0 0[7],使用的是分段线性的损失函数。
从损失函数的角度,Basic GAN几乎选择了一个最差的方案——经过sigmoid激活后,损失函数在0-1两端都存在饱和区。
关于GLSGAN
当 γ < 0 \gamma < 0 γ<0 时,从形式上,GLSGAN 其实已经不能叫做 Loss Sensitive了。因为此时 GLSGAN 的行为更向是 Least Squares GAN——将生成样本向某个点推(零点)。不过 GLSGAN 对于真实样本更激进,它会不断将真实样本向 + ∞ +\infty +∞ 推。另一个不同是,GLSGAN 使用线性的函数,而 LSGAN 使用二次函数。
TODO 此处有一个疑问待解决:文章中说 Least Squares GAN也存在梯度消失的问题。从形式上看,虽然一次函数在极值附近梯度接近0,但由于正负样本的损失函数的极值点不同,因此,直觉上,在对抗训练过程中应该不会出现梯度消失的现象。看到需要进一步提高理论修养。
Regularities
这里我们没有关注正则性约束,但 WGAN, GLSGAN 要求判别器是 Lipschitz(相对于模型参数)。直观上,Lipschitz 保证训练过程中,不会因为参数更新引起模型的跳跃性变化,确保训练过程平稳。
6. 结语
- 形式上,各种方法仅仅是损失函数不太一样,但损失函数的选择并不trivial。basic GAN 训练困难已经表明了 GAN 对抗的训练方式对损失函数的非常的敏感。没有严谨的理论支撑,随意的损失函数并不能保证训练如预期进行(收敛且稳定)。
- 鲁棒的 GAN 训练方法对于 GAN 在广阔领域的应用将是非常大的推动力(如最近的压缩感知应用)。
- 对不同损失函数(不同 GAN)的性质,目前还缺少系统性的比较研究,期待更新的研究结果。
References
- Ian Goodfellow et al. (2014). Generative Adversarial Networks.
- Ian Goodfellow. (2016). NIPS 2016 Tutorial: Generative Adversarial Networks.
- Nowozin et al. (2016). f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization.
- Mao et al. (2016). Least Squares Generative Adversarial Networks.
- Arjovsky et al. (2016). Towards Principled Methods for Training Generative Adversarial Networks.
- Arjovsky et al. (2017). Wasserstein GAN.
- Qi. (2017). Loss-Sensitive Generative Adversarial Networks on Lipschitz Densities.
- An Incomplete Map of the GAN models.
- LS-GAN:把GAN建立在Lipschitz密度上.
- 广义LS-GAN(GLS-GAN) .
Further Reading
- Arora et al. (2017). Generalization and Equilibrium in Generative Adversarial Nets (GANs).
- Arora et al. (2017). Do GANs actually learn the distribution? An empirical study.
- Odena. Open Questions about Generative Adversarial Networks(very insightful discussion)