在学习 Variational Auto-Encoder 时,同时注意到了 GAN 研究的火热。但当时觉得 GAN 非常不成熟(训练不稳定,依赖各种说不清的 tricks;没有有效的监控指标,需要大量的人工判断,因此难以扩展到图像之外的高维数据)。在读了 Goodfellow 的 tutorial 后[2],开始黑转路人,觉得 GAN 虽然缺点不少,但优点也很明显。WGAN[5, 6] 等工作出现后,开始逐渐路人转粉,对 GAN 产生了兴趣。
这里,我们仅仅从直观上讨论GAN框架及相关变种,将理论留待将来讨论。
1. Basic GAN
本质上,GAN 是一种训练模式,而非一种待定的网络结构[1]。
图1. GAN基本框架【src】
GAN 的基本思想是,生成器和判别器玩一场“道高一尺,魔高一丈”的游戏:判别器要练就“火眼金睛”,尽量区分出真实的样本(如真实的图片)和由生成器生成的假样本;生成器要学着“以假乱真”,生成出使判别器判别为真实的“假样本”。
竞争的理想怦是双方都不断进步——(理想情况下)判别器的眼睛越发“雪亮”,生成器的欺骗能力也不断提高。对抗的胜负无关紧要,重要的是,最后生成器的欺骗能力足够好,能够生成与真实样本足够相似的样本——直观而言,生成的样本看起来像是训练集(如图片)的样本;形式化的,生成器生成样本的分布,应该与训练集样本分布接近。
理论上可以,在理想条件下,生成器是可以通过这种对抗得到目标分布的(即生成足够真实的样本)。
假设要训练数据为灰度 MNIST(归一化为[0, 1]之间),生成器(generator)可以为任意输入为隐变量维度,输出为 1 x 28 x 28的模型。一个示例模型定义如下:
def build_generator(latent_size):
model = Sequential()
model.add(Dense(1024, input_dim=latent_size, activation='relu'))
model.add(Dense(28 * 28, activation='tanh'))
model.add(Reshape((1, 28, 28)))
return model
判别器(discriminator)可以为任意输入 1 x 28 x 28,输出为1维且在 [0, 1] 之间(经过sigmoid激活)的模型。一个示例模型定义如下:
def build_discriminator():
model = Sequential()
model.add(Flatten(input_shape=(1, 28, 28)))
model.add(Dense(256, activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(1), activation='sigmoid')
return model
输出值表示判别器判别输入样本为真的概率。即输出值越接近1,判别器越确信样本为真;输出值越接近0,判别器越确信样本为假。
判别器
L D = − Σ i log ( D ( x i ) ) − Σ i log ( 1 − D ( G ( z i ) ) ) L_D = -\Sigma_i \log(D(\textbf{x}_i)) -\Sigma_i \log(1-D(G(\textbf{z}_i))) LD=−Σilog(D(xi))−Σilog(1−D(G(zi)))
判别器的训练的目标为:对于真实样本,输出尽量接近1;对于生成器生成的假样本,输出尽量接近0。
也即训练判别器时,真实样本的标签为1,生成样本的标签为0。
生成器
L G = Σ i log ( 1 − D ( G