当你看到以假乱真的图片或视频,看到风格迁移的图片或视频,你应当知道,其背后的机器学习技术是GAN!
GAN, generative adversarial network, 生成式对抗神经网络, 是生成模型的一种。
生成模型主要分两种,一种由输入数据,得到概率密度分布,另外一种,由输入数据,得到与输入数据相同分布的输出数据,GAN属于第二种。更多的关于生成模型的分类,见下图。
GAN是怎样工作的呢?
GAN有两个网络,一个是生成器,希望生成同训练数据相同分布的样本,一个是判别器,希望将生成数据(fake)和训练数据(real)区分开来。
判别器希望real的output接近1,fake的output接近0,下面是判别器的损失函数的定义(只有一种):
def discriminator_loss(real_output, fake_output):
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss
生成器的损失函数有三种(零和游戏,非饱和游戏和最大似然游戏), 在非饱和游戏中,生成器希望fake经过判别器判别的output接近1,非饱和游戏中生成器的损失函数的定义如下:
def generator_loss(fake_output):
return cross_entropy(tf.ones_like(fake_output), fake_output)
下图中的D表示判别器(函数),G表示生成器(函数):
需要注意,生成器的损失函数既依赖于生成器神经网络的参数,也依赖于判别器神经网络的参数;同样判别器的损失函数既依赖于判别器神经网络的参数,也依赖于生成器神经网络的参数。
训练GAN是一个博弈的过程,需要找到纳什均衡。
以上图片来自于 NIPS 2016 Tutorial: Generative Adversarial Networks by Ian Goodfellow
用来表示损失函数的定义的示例代码来自于 tensorflow tutorials https://www.tensorflow.org/tutorials/generative/dcgan
祖国翔,于上海