前言:最近在学习生成对抗网络(GAN, Generative Adversarial Networks),为了加深自己的理解,并帮助到想入门的同学,我特意写了这篇文章,教大家一步步搭建一个最简单原始的GAN网络 (Vanilla GAN)。代码后面会有详细(通俗易懂)的解释,大神请自动绕路~欢迎小白玩家围观~~ 查看本文jupyter notebook代码请点击这里。
首先,让我们简单回顾一下什么是GAN。
GAN最早由GoodFellow在2014年提出,查看原始论文请点击这里。GAN结构如图1所示,包含了一个生成器(Generator)和一个判别器 (Discriminator)。生成器的目的是生成以假乱真的图片,而判别器的目的是尽可能区分输入图片的真假。
举一个简单的例子,比如说假钞的流通。犯罪分子希望制作出逼真的假钞,可是警察的鉴定技术也在不断改良,双方互相博弈,互相提高,最终达到一种动态的平衡。讲到这里,是不是感觉很简单?
鉴于这是一个超级良心的教程~大家可以先跟着我一起,把代码实现。实现过程中,有不懂的先不要问(对,憋着),等跑完代码之后,看到酷炫的效果后,我再一步步解释为啥这么写。
好了,是时候放上代码了。来,先导入包。
from keras.datasets import mnist
from keras.layers import Dense, Dropout, Input
from keras.models import Model,Sequential
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from google.colab import drive
然后,读取Keras自带的mnist数据集。在这里我们给出一个读取数据的函数load_data()。
# Load the dataset
def load_data():
(x_train, y_train), (_, _) = mnist.load_data()
x_train = (x_train.astype(np.float32) - 127.5)/127.5
# Convert shape from (60000, 28, 28) to (60000, 784)
x_train = x_train.reshape(60000, 784)
return (x_train, y_train)
X_train, y_train = load_data()
print(X_train.shape, y_train.shape)
由于本文我们旨在实现最原始的GAN网络,因此用最简单MLP全连接层来构建生成器(用卷积层当然更好,在这里先不考虑)
def build_generator():
model = Sequential()
model.add(Dense(units=256, input_dim=100))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(units=512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(units=1024))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(units=784, activation='tanh'))
model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
return model
generator = build_generator()
generator.summary()
生成器结构如下图所示:
然后建一个判别器,也是一个MLP全连接神经网络:
def build_discriminator():
model = Sequential()
model.add(Dense(units=1024 ,input_dim=784))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.3))
model.add(Dense(units=512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.3))
model.add(Dense(units=256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.3))
model.add(Dense(units=1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
return model
discriminator = build_discriminator()
discriminator.summary()
判别器结构如图所示:
然后,我们建立一个GAN网络,由discriminator和generator组成。
def build_GAN(discriminator, generator):
discriminator.trainable=False
GAN_input = Input(shape=(100,))
x = generator(GAN_input)
GAN_output= discriminator(x)
GAN = Model(inputs=GAN_input, outputs=GAN_output)
GAN.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
return GAN
GAN = build_GAN(discriminator, generator)
GAN.summary()
GAN结构如下图所示
然后我们给出绘制图像的函数,用于把generator生成的假图片画出来:
def draw_images(generator, epoch, examples=25, dim=(5,5), figsize=(10,10)):
noise= np.random.normal(loc=0, scale=1, size=[examples, 100])
generated_images = generator.predict(noise)
generated_images = generated_images.reshape(25,28,28)
plt.figure(figsize=figsize)
for i in range(generated_images.shape[0]):
plt.subplot(dim[0], dim[1], i+1)
plt.imshow(generated_images[i], interpolation='nearest', cmap='Greys')
plt.axis('off')
plt.tight_layout()
plt.savefig('Generated_images %d.png' %epoch)
OK, 最后一步,写一个train函数,来训练GAN网络。在这里我们设置最大迭代次数400,每次迭代生成128张假图片:
def train_GAN(epochs=1, batch_size=128):
#Loading the data
X_train, y_train = load_data()
# Creating GAN
generator= build_generator()
discriminator= build_discriminator()
GAN = build_GAN(discriminator, generator)
for i in range(1, epochs+1):
print("Epoch %d" %i)
for _ in tqdm(range(batch_size)):
# Generate fake images from random noiset
noise= np.random.normal(0,1, (batch_size, 100))
fake_images = generator.predict(noise)
# Select a random batch of real images from MNIST
real_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]
# Labels for fake and real images
label_fake = np.zeros(batch_size)
label_real = np.ones(batch_size)
# Concatenate fake and real images
X = np.concatenate([fake_images, real_images])
y = np.concatenate([label_fake, label_real])
# Train the discriminator
discriminator.trainable=True
discriminator.train_on_batch(X, y)
# Train the generator/chained GAN model (with frozen weights in discriminator)
discriminator.trainable=False
GAN.train_on_batch(noise, label_real)
# Draw generated images every 15 epoches
if i == 1 or i % 10 == 0:
draw_images(generator, i)
train_GAN(epochs=400, batch_size=128)
我用了Google colab自带的GPU,训练400代大约用了十多分钟。如果用jupyter notebook在本机跑,会慢一些 (据说2分钟一代?)。
生成的图片如下图所示
大功告成,接下来我将一步步解释train_GAN()函数是怎么工作的。
首先,导入数据集,这个容易理解。
#Loading the data
X_train, y_train = load_data()
接下来,建立一个GAN网络,GAN由两个神经网络(generator, discriminator)连接而成。
# Creating GAN
generator= build_generator()
discriminator= build_discriminator()
GAN = build_GAN(discriminator, generator)
然后,建立一个循环(400次迭代)。tqdm用来动态显示每次迭代的进度。
for i in range(1, epochs+1):
print("Epoch %d" %i)
for _ in tqdm(range(batch_size)):
接着,我们生成呈高斯分布的噪声,利用generator,来生成batch_size(128张)图片。每张图片的输入就是一个1*100的噪声矩阵。
# Generate fake images from random noiset
noise= np.random.normal(0,1, (batch_size, 100))
fake_images = generator.predict(noise)
同样的,我们从Mnist数据集中随机挑选128张真实图片。我们给真实图片标注1,给假图片标注0,然后将256张真假图片混合在一起。
# Select a random batch of real images from MNIST
real_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]
# Labels for fake and real images
label_fake = np.zeros(batch_size)
label_real = np.ones(batch_size)
# Concatenate fake and real images
X = np.concatenate([fake_images, real_images])
y = np.concatenate([label_fake, label_real])
此时,我们利用上文提到的256张带标签的真假图片,训练discriminator。训练完毕后,discriminator的weights得到了更新。(打个比方,警察通过研究市面上流通的假币,在一起开会讨论,努力研发出了新一代鉴定假钞的方法)。
# Train the discriminator
discriminator.trainable=True
discriminator.train_on_batch(X, y)
然后,我们冻结住discriminator的weights,让discriminator不再变化。然后就开始训练generator (chained GAN)。在GAN的训练中,我们输入一堆噪声,期待的输出是将假图片预测为真。在这个过程中,generator继续生成假图片,送到discriminator检验,得到检验结果,如果被鉴定为假,就不断更新自己的权重(假钞贩子不断改良造假技术),直到discriminator将加图片鉴定为真图片(直到当前鉴定假钞的技术无法识别出假钞)。
# Train the generator/chained GAN model (with frozen weights in discriminator)
discriminator.trainable=False
GAN.train_on_batch(noise, label_real)
OK,此时一次迭代进行完毕。接下来是第2, 3, ...次迭代。
现在,我们总结一下每次迭代发生了什么:
- Generator利用自己最新的权重,生成了一堆假图片。
- Discrminator根据真假图片的真实label,不断训练更新自己的权重,直到可以顺利鉴别真假图片。
- 此时discriminator权重被固定,不再发生变化。generator利用最新的discrimintor,苦苦思索,不断训练自己的权重,最终使discriminator将假图片鉴定为真图片。
换成印制假钞的例子,每次迭代发生了如下几件事:
- 假钞贩子根据最新造假技术,研发出一代假钞。
- 警察反复对比新型假钞和真币的区别,成功改良假钞鉴别方法,从而顺利鉴别出市面流通钞票的真伪。
- 假钞贩子生成假钞,马上被警察鉴别出来,痛定思痛,改良技术生成新的假钞。不成想,一上街又被警察识别了出来。日复一日,终于发明了新型假钞,当前的验钞技术已经无法成功检测出这种假钞。
然后通过每次迭代,discrimintor (警察的鉴定技术)和generator (假钞制作技术) 都越来越成熟...后来达到了动态平衡。
嗯,就这样,是不是挺简单的?
今天讲的是最原始的GAN网络,GAN发展到了如今已有许多变种,如将MLP结构换成CNN,Autoencoder,以及loss function的变化等等。我在github上找到一个超级全的用keras编写的各种花式GAN网络集合,有兴趣的小伙伴直接点击这里。本文的jupyter notebook代码请直接点击下面的小卡片~
https://nbviewer.jupyter.org/github/gaonanlee/Deep-Learning-Experiments/blob/master/Vanilla%20GAN_implementation.ipynbnbviewer.jupyter.org如有理解不到位之处,欢迎批评指教。
参考文献
- https://github.com/eriklindernoren/Keras-GAN/blob/master/cgan/cgan.py
2. https://medium.com/datadriveninvestor/generative-adversarial-network-gan-using-keras-ce1c05cfdfd3
我的其他回答:
哪些 Python 库让你相见恨晚?
python如何画出漂亮的地图?
时间序列数据如何插补缺失值?
机器学习中的因果关系: 从辛普森悖论(常见的统计学谬误)谈起