【代码篇】【1】详解GAN代码生成mnist图片(keras)
0.GAN的基本概念
GAN (Generative Adversarial Networks)从其名字可以看出,是一种生成式的对抗网络。再具体一点, 就是通过对抗的方式,去学习数据分布的生成式模型。所谓的对抗,指的是生成网络和判别网络的互相对抗。生成网络尽可能生成逼真的样本,判别网络则尽可能去判别该样本是真实样本,还是生成的假样本。
隐变量z (通常为服从高斯分布的随机噪声)通过Generator生成X fake,判别器负责判别输入的data是生成的样本X fake还是真实样本Xreal。优化的目标函数如下:
对于判别器D来说,这是一个二分类问题,V(D, G)为二分类问题中常见的交叉熵损失。对于生成器G来说,为了尽可能欺骗D,所以需要最大化生成样本的判别概率D(G(z)),即最小化log(1 - D(G())),log(D(x))- -项与生成器G无关,可以忽略。
1.如何训练GAN?
实际训练时,生成器和判别器采取交替训练,即先训练D,然后训练G,不断往复。值得注意的是,对于生成器,其最小化的是max V(D, G),即最小化V(D, G)的最大值。为了保证V (D, G)取得最大值,所以我们通常会训练迭代k次判别器,然后再迭代1次生成器(不过在实践当中发现,k通常取1即可)。当生成器G固定时,我们可以对V(D, G)求导,求出最优判别器D* (x):
把最优判别器代入上述目标函数,可以进一步求出在最优判别器下,生成器的目
标函数等价于优化Pdata (x), Pg(x )的JS散度(JSD, Jenson Shannon Divergence)。可以证明,当G, D二者的capacity足够时,模型会收敛,二者将达到纳什均衡。此时,Pdata(x)= Pg(x),判别器不论是对于pdata (x)还是pg(x)中采样的样本,其预测概率均为0.5,即生成样本与真实样本达到了难以区分的地步。
3.GAN的常见模型DCGAN
DCGAN的全称是Deep Convolutional Generative Adversarial Networks ,意即深度卷积对抗生成网络。它是由Alec Radford在论文Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks中提出的。实际上它是在GAN的基础上增加深度卷积网络结构。
DCGAN提出使用CNN结构来稳定GAN的训练,并使用了以下一些trick:
4.基于keras的DCGAN
代码主要是参考Bubbliiiing同学的代码,稍微的修改了网络结构,一共分为5个模块讲解:
Keras搭建DCGAN利用深度卷积神经网络实现图片生成
4.1生成器generator
生成器的输入为基于正态分布的N维向量noise,输出为28x28x1的mnist图片。
首先输入noise,全连接到7x7x16大小( latent_dim —> 7x7x16)。
然后使用多次上采样+卷积+batchnorm+relu模块,直到28x28x1大小。
值得注意的是,最后一层使用tanh激活函数,效果要好些。
model.summary()为打印模型。
def make_generator(self):
#----------------------------#
# make generator #
#----------------------------#
model = Sequential()
# latent_dim ---> 7x7x16
model.add(Dense(7*7*16, activation='relu', input_dim=self.latent_dim))
model.add(Reshape((7,7,16)))
# 7x7x16 ---> 7x7x32
model.add(Conv2D(32, kernel_size=3, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
# 7x7x32 ---> 14x14x64
model.add(UpSampling2D())
model.add(Conv2D(64, kernel_size=3, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
# 7x7x64 ---> 28x28x128
model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=3, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
# 28x28x128 ---> 28x28x32
model.add(