一、简介
之前介绍了生成式对抗网络(GAN),关于GAN的变种比较多,我打算将几种常见的GAN做一个总结,也算是激励自己学习,分享自己的一些看法和见解。
之前提到的GAN是最基本的模型,我们的输入是随机噪声,输出的是对应的图像,但是我们没法控制生成图像的类型。比如,我要生成一张数字0的图片,但是GAN生成的图片却是数字0-9的图片,针对这个问题,Conditional Generative Adversarial Nets被提了出来,在原有GAN的基础上,添加了类别信息以便让模型生成特定的图片。这里的条件(conditional),就是这个额外的类别信息。
二、原理
由于在GAN的生成器和判别器中都加入了额外的类别信息,模型的目标优化函数也发生了变化。
生成器的输入变为噪音变量和类别信息
, 判别器的输入为图片数据
和类别信息
, 目标函数如下:
就是在GAN的目标函数上添加了y这一类别变量,x变为了条件分布。
模型的结构图如下,
GAN的结构与这个类似,生成器部分和判别器部分是分开的两个子网络,单独进行训练。类别信息y是通过embedding层嵌入的。
具体的实现可以看看代码:
生成器:
def build_generator(self):
model = Sequential()
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
model.summary()
noise = Input(shape=(self.latent_dim,))
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
model_input = multiply([noise, label_embedding])
img = model(model_input)
return Model([noise, label], img)
标签是通过嵌入层实现的,embedding层可以将类别标签转换为对应的向量表示,在此生成器中,类别有10个(0-9),对应embedding中的input_dim, 输出维度和噪音数据是相同的,之后,再利用multiply层将两者逐项做乘积,这便是生成器的输入。
判别器:
def build_discriminator(self):
model = Sequential()
model.add(Dense(512, input_dim=np.prod(self.img_shape)))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.4))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.4))
model.add(Dense(1, activation='sigmoid'))
model.summary()
img = Input(shape=self.img_shape)
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))
flat_img = Flatten()(img)
model_input = multiply([flat_img, label_embedding])
validity = model(model_input)
return Model([img, label], validity)
判别器的输入和生成器是一样的,输出是对应的图片的类别。
训练:
训练采用的mnist数据集,训练时需要将图片数据和对应的标签输入模型。
生成器和判决器作为一个整体进行训练的时候,判别器是不训练的,这时只训练生成器;当判决器作为一个单独的模型时,判决器会得到训练。二者的训练是交替进行的。
具体的代码可以参考github
三、效果
最后跑出来的效果还是很不错的,我在台式机上跑的,用的是1050ti的显卡,训练速度还比较快,一共20000轮,大概10分钟左右跑完。
这是最后的训练效果:
可以与前一篇博客里面的内容进行比较,与原始的GAN相比,效果要好一些,但是还是不是很清晰。一方面,mnist提供的图片像素较低,另一方面,我们采用的是全连接神经网络,对于图片的处理效果并不是很好。
要生成更加清晰地图片,可以利用DCGAN,这也是我接下来要做的工作。