【Learning Notes】生成式对抗网络(Generative Adversarial Networks,GAN)

在学习 Variational Auto-Encoder 时,同时注意到了 GAN 研究的火热。但当时觉得 GAN 非常不成熟(训练不稳定,依赖各种说不清的 tricks;没有有效的监控指标,需要大量的人工判断,因此难以扩展到图像之外的高维数据)。在读了 Goodfellow 的 tutorial 后[2],开始黑转路人,觉得 GAN 虽然缺点不少,但优点也很明显。WGAN[5, 6] 等工作出现后,开始逐渐路人转粉,对 GAN 产生了兴趣。

这里,我们仅仅从直观上讨论GAN框架及相关变种,将理论留待将来讨论。

1. Basic GAN

本质上,GAN 是一种训练模式,而非一种待定的网络结构[1]。

图1. GAN基本框架【src

GAN 的基本思想是,生成器和判别器玩一场“道高一尺,魔高一丈”的游戏:判别器要练就“火眼金睛”,尽量区分出真实的样本(如真实的图片)和由生成器生成的假样本;生成器要学着“以假乱真”,生成出使判别器判别为真实的“假样本”。

竞争的理想怦是双方都不断进步——(理想情况下)判别器的眼睛越发“雪亮”,生成器的欺骗能力也不断提高。对抗的胜负无关紧要,重要的是,最后生成器的欺骗能力足够好,能够生成与真实样本足够相似的样本——直观而言,生成的样本看起来像是训练集(如图片)的样本;形式化的,生成器生成样本的分布,应该与训练集样本分布接近。

理论上可以,在理想条件下,生成器是可以通过这种对抗得到目标分布的(即生成足够真实的样本)。

假设要训练数据为灰度 MNIST(归一化为[0, 1]之间),生成器(generator)可以为任意输入为隐变量维度,输出为 1 x 28 x 28的模型。一个示例模型定义如下:

def build_generator(latent_size):
	model = Sequential()
	model.add(Dense(1024, input_dim=latent_size, activation='relu'))
	model.add(Dense(28 * 28, activation='tanh'))
	model.add(Reshape((1, 28, 28)))
	return model

判别器(discriminator)可以为任意输入 1 x 28 x 28,输出为1维且在 [0, 1] 之间(经过sigmoid激活)的模型。一个示例模型定义如下:

def build_discriminator():
	model = Sequential()
	model.add(Flatten(input_shape=(1, 28, 28)))
	model.add(Dense(256, activation='relu'))
	model.add(Dense(128, activation='relu'))
	model.add(Dense(1), activation='sigmoid')	
	return model

输出值表示判别器判别输入样本为真的概率。即输出值越接近1,判别器越确信样本为真;输出值越接近0,判别器越确信样本为假。

判别器

L D = − Σ i log ⁡ ( D ( x i ) ) − Σ i log ⁡ ( 1 − D ( G ( z i ) ) ) L_D = -\Sigma_i \log(D(\textbf{x}_i)) -\Sigma_i \log(1-D(G(\textbf{z}_i))) LD=Σilog(D(xi))Σilog(1D(G(zi)))
判别器的训练的目标为:对于真实样本,输出尽量接近1;对于生成器生成的假样本,输出尽量接近0。
也即训练判别器时,真实样本的标签为1,生成样本的标签为0。

生成器

L G = Σ i log ⁡ ( 1 − D ( G ( z i ) ) ) L_G = \Sigma_i \log(1-D(G(\textbf{z}_i))) LG=Σilog(1D(G(zi)))
判别器的训练的目标为生成的假样本,使判别器的输出尽量接近1,即尽量以假乱真。
为了解决训练过程中,梯度消失的问题,一般使用如下损失函数 (Trick 2):
L G = − Σ i log ⁡ ( D ( G ( z i ) ) ) L_G = -\Sigma_i \log(D(G(\textbf{z}_i))) LG=Σilog(D(G(zi)))
为使用这个损失函数,只需要将生成样本的标签为1,同时使用变通的交叉熵损失函数。

GAN 的训练流程如下[1]:
这里写图片描述
∇ θ 1 m Σ i = 1 m − log ⁡ ( D ( G ( z ( i ) ) ) ) \nabla\theta\frac{1}{m}\Sigma_{i=1}^{m}-\log(D(G(\textbf{z}^{(i)}))) θm1Σi=1mlog(D(G(z(i))))

GAN足够简单,也有理论上的保证。但在实践中,需要许多技巧和运气才能正常把“游戏玩下去”。这里,我们不考虑理论,而是关注不要GAN变种在损失函数设计的差异。

2. Least Squares GAN

我们以[4]中 Eq (9) 为例来介绍 LSGAN。其中判别器的定义如下:

def build_discriminator():
	model = Sequential()
	model.add(Flatten(input_shape=(1, 28, 28)))
	model.add(Dense(256, activation='relu'))
	model.add(Dense(128, activation='relu'))
	model.add(Dense(1), activation='linear') ## change 1	
	return model

与basic GAN 唯一不同在判别器的最后输出不使用 sigmoid 激活,而是使用了线性函数(也即不使用激活)(第6行 change 1)。

有了生成器和判别器的定义,我们来实际构造两者以用于训练:

# 构造判别器
disc = build_discriminator()
disc.compile(optimizer=Adam(lr=lr),loss='mse')

# 构建生成器
generator = build_generator(latent_size)
latent = Input(shape=(latent_size,))
# 生成假图片
fake = generator(latent)
# 我们要训练生成器,因此固定判别的权值不变
disc.trainable = False
fake = disc(fake)
combined = Model(input=latent, output=fake)
combined.compile(optimizer=Adam(lr=lr), loss='mse')

不同于basic GAN, LSGAN的训练损失函数由交叉熵改为MSE(Mean Squared Error)。

for epoch in range(nb_epochs):
    for index in range(nb_batches):
        ## 1) 训练判别器 
        # 1.1采样隐变量并生成假样本
        noise = np.random.uniform(-1, 1, (batch_size, latent_size))
        generated_images = generator.predict(noise, verbose=0)
        # 1.2 从训练中采样真实样本
        image_batch = X_train[index * batch_size:(index + 1) * batch_size]
        label_batch = y_train[index * batch_size:(index + 1) * batch_size]        
        # 利用真假数据进行训练
        X = np.concatenate((image_batch, generated_images))
        # 设定真假数据的损失,a == 0, b == 1
        y = np.array([1] * len(image_batch) + [0] * batch_size)
        disc.train_on_batch(X, y)

        ## 2)训练生成器
        # 采样隐变量       
        noise = np.random.uniform(-1, 1, (batch_size, latent_size))
        target = np.ones(batch_size) # 设定生成样本的损失 c == b == 1
        combined.train_on_batch(noise, target)

图2是训练过程中,由生成器采样的几张示例图片。完整的示例可以参见repo

LS-GAN
图2. LSGAN随机采样生成的图片(Epoch: 443)

由于仅作为示例以及时间和计算资源的限制,从模型结构到优化器的参数都没有经过任何调优。因此,这里生成的图片的质量不应该做为算法优劣的依据(下同)。

3. Wasserstein GAN(WGAN)

WGAN采用线性的损失函数,为此我们定义:

 def dummy_loss(loss_to_backprop, y_pred):
	return K.mean(loss_to_backprop * y_pred) # delta == loss_to_backprop

disc.compile(optimizer=Adam(lr=lr),loss=dummy_loss)
combined.compile(optimizer=Adam(lr=lr), loss=dummy_loss)

为应用这个损失函数,代码更改如下(第12和18行,change 2、3)。

for epoch in range(nb_epochs):
    for index in range(nb_batches):
        ## 1) 训练判别器 
        # 1.1采样隐变量并生成假样本
        noise = np.random.uniform(-1, 1, (batch_size, latent_size))
        generated_images = generator.predict(noise, verbose=0)
        # 1.2 从训练中采样真实样本
        image_batch = X_train[index * batch_size:(index + 1) * batch_size]
        label_batch = y_train[index * batch_size:(index + 1) * batch_size]        
        # 利用真假数据进行训练
        X = np.concatenate((image_batch, generated_images))        
        y = np.array([-1] * len(image_batch) + [1] * batch_size) ## change 2
        disc.train_on_batch(X, y)

        ## 2)训练生成器
        # 采样隐变量       
        noise = np.random.uniform(-1, 1, (batch_size, latent_size))
        target = -np.ones(batch_size) ## change 3
        combined.train_on_batch(noise, target)

WGAN有如下突出优点[6]:

  • 训练稳定,不需要平稳生成器和判别器。
  • loss 值与生成样本质量相关,可以用来监督训练进程,不需要人工判断干预。

完整的示例可以参见repo。读者可以自行验证,D_loss 及生成图像的质量变化。

4. GLSGAN

[7]提出了 Loss Sensitive GAN,并随后发现,可以和WGAN在统一的框架下研究,即 generalized LSGAN(图3)。

这里写图片描述
图3. 【src

GLSGAN 使用 LeakyReLU作用激活,其中 s ∈ ( − ∞ , 1 ] s \in (-\infty, 1] s,1]
KaTeX parse error: Unknown column alignment: x at position 37: …\begin{array} x̲x, \ if\ x \ge…

def build_discriminator():
	model = Sequential()
	model.add(Flatten(input_shape=(1, 28, 28)))
	model.add(Dense(256, activation='relu'))
	model.add(Dense(128, activation='relu'))
	model.add(Dense(1), activation='linear')
	model.add(LeakyReLU(slope))	##
	return model

下面是不同s下,训练的模型生成的示例图片。

slope==1, WGAN
图4. Slope: 1 (WGAN), Epoch: 170

slope==0, LS-GAN
图5. Slope: 0 (Loss Sensitive GAN), Epoch: 189

slope==-1, L1
图6. Slope: -1 (L1 Loss), Epoch: 399

非线性损失

GLSGAN并不限定损失函数为(分段)线性。这里使用 Exponential Linear Unit(ELU)。
KaTeX parse error: Unknown column alignment: x at position 31: …\begin{array} x̲x, \ if\ x \ge…
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-dILAbTgm-1572062021724)(https://ai2-s2-public.s3.amazonaws.com/figures/2016-11-08/0373b97580cdfd0b69f165e1a946bae62da95dce/1-Figure2-1.png)]
图7. Exponential Linear Unit vs. ReLU【src

def build_discriminator():
	# Other Code goes here...
	model.add(ELU)	# Exponential Linear Unit
	return model

ELU
图8. ELU, Epoch: 367

一个完整的示例见repo。基于 torch 的official repo

5. 讨论

损失函数

损失函数唯一重要的地方在于,不断驱动两个网络的竞争。直观上,判别网络将真实样本和生成样本,向坐标轴上的两个不同的区域移动。

  • 对于basic GAN,这两个区域分别分别是0(生成)和1(真实),使用的损失函数是对数函数( f ( x ) = − log ⁡ ( x ) f(x)=-\log(x) f(x)=log(x))(即交叉熵)[1]。
  • 对于 Least Squared GAN,这两个区域分别是 a 和 b(a < b),使用的损失函数是二次函数( f ( x ) = x 2 f(x)= x^2 f(x)=x2)[4]。
  • 对于 WGAN,这两个区域分别 + ∞ +\infty +(真实样本)和 − ∞ -\infty (生成样本),使用的是线性损失函数( f ( x ) = x f(x)=x f(x)=x)[6]。
  • 对于 Loss-Sensitive GAN,这两个区域分别是 + ∞ +\infty +(真实样本)和 ( − ∞ , 0 ] (-\infty, 0] (,0][7],使用的是ReLu损失函数。
  • 对于 Generalized LSGAN( 0 < γ < 1 0 < \gamma < 1 0<γ<1),这两个区域分别是 + ∞ +\infty +(真实样本)和 ( − ∞ , 0 ] (-\infty, 0] (,0][7],使用的是Leaky ReLu损失函数。
  • 对于 Generalized LSGAN( γ < 0 \gamma < 0 γ<0),这两个区域分别是 + ∞ +\infty +(真实样本)和 0 0 0[7],使用的是分段线性的损失函数。

从损失函数的角度,Basic GAN几乎选择了一个最差的方案——经过sigmoid激活后,损失函数在0-1两端都存在饱和区。

关于GLSGAN

γ < 0 \gamma < 0 γ<0 时,从形式上,GLSGAN 其实已经不能叫做 Loss Sensitive了。因为此时 GLSGAN 的行为更向是 Least Squares GAN——将生成样本向某个点推(零点)。不过 GLSGAN 对于真实样本更激进,它会不断将真实样本向 + ∞ +\infty + 推。另一个不同是,GLSGAN 使用线性的函数,而 LSGAN 使用二次函数。

TODO 此处有一个疑问待解决:文章中说 Least Squares GAN也存在梯度消失的问题。从形式上看,虽然一次函数在极值附近梯度接近0,但由于正负样本的损失函数的极值点不同,因此,直觉上,在对抗训练过程中应该不会出现梯度消失的现象。看到需要进一步提高理论修养。

Regularities

这里我们没有关注正则性约束,但 WGAN, GLSGAN 要求判别器是 Lipschitz(相对于模型参数)。直观上,Lipschitz 保证训练过程中,不会因为参数更新引起模型的跳跃性变化,确保训练过程平稳。

6. 结语

  • 形式上,各种方法仅仅是损失函数不太一样,但损失函数的选择并不trivial。basic GAN 训练困难已经表明了 GAN 对抗的训练方式对损失函数的非常的敏感。没有严谨的理论支撑,随意的损失函数并不能保证训练如预期进行(收敛且稳定)。
  • 鲁棒的 GAN 训练方法对于 GAN 在广阔领域的应用将是非常大的推动力(如最近的压缩感知应用)。
  • 对不同损失函数(不同 GAN)的性质,目前还缺少系统性的比较研究,期待更新的研究结果。

References

  1. Ian Goodfellow et al. (2014). Generative Adversarial Networks.
  2. Ian Goodfellow. (2016). NIPS 2016 Tutorial: Generative Adversarial Networks.
  3. Nowozin et al. (2016). f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization.
  4. Mao et al. (2016). Least Squares Generative Adversarial Networks.
  5. Arjovsky et al. (2016). Towards Principled Methods for Training Generative Adversarial Networks.
  6. Arjovsky et al. (2017). Wasserstein GAN.
  7. Qi. (2017). Loss-Sensitive Generative Adversarial Networks on Lipschitz Densities.
  8. An Incomplete Map of the GAN models.
  9. LS-GAN:把GAN建立在Lipschitz密度上.
  10. 广义LS-GAN(GLS-GAN) .

Further Reading

  • 2
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值