《文章》介绍了使用全连接构建对抗生成神经网络GAN,本文介绍基于卷积构建对抗生成神经网络DCGAN(深度卷积对抗生成网络,Deep Convolutional Generative Adversarial Networks)。
DCGAN的特点
- 判别模型:使用带步长的卷积(strided convolutions)取代了空间池化(spatial pooling),容许网络学习自己的空间下采样(spatial downsampling)
- 生成模型:使用微步幅卷积(fractionally-strided convolutions),容许它学习自己的空间上采样(spatial upsampling)
- 激活函数:LeakyReLU
- Batch Normalization 批标准化:解决因糟糕的初始化引起的训练问题,使得梯度能传播更深层次。Batch Normalization证明了生成模型初始化的重要性,避免生成模型崩溃:生成的所有样本都在一个点上(样本相同),这是训练GANs经常遇到的失败现象。
生成器:
判别器:
反卷积:
就是把卷积的前向和反向传播完全颠倒了
实现
import tensorflow as tf
import numpy as np
import pickle # 把结果保存至本地
import matplotlib.pyplot as plt
%matplotlib inline
# Load the MNIST dataset (downloaded into ./data/ on first run).
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./data/')
# 获得数据
def get_inputs(noise_dim, image_height, image_width, image_depth):
    """Create the two graph inputs: real-image batches and generator noise.

    Returns a (inputs_real, inputs_noise) pair of float32 placeholders with
    a free batch dimension.
    """
    image_shape = [None, image_height, image_width, image_depth]
    inputs_real = tf.placeholder(tf.float32, image_shape, name='inputs_real')
    inputs_noise = tf.placeholder(tf.float32, [None, noise_dim], name='inputs_noise')
    return inputs_real, inputs_noise
# 生成器
# Generator
def get_generator(noise_img, output_dim, is_train=True, alpha=0.01):
    """Map a noise vector to a generated 28x28 image.

    Args:
        noise_img: noise tensor of shape (batch, noise_dim).
        output_dim: channel count of the generated image (1 for MNIST).
        is_train: True while training. Also controls variable reuse
            (inference calls with is_train=False reuse the training
            variables) and the batch-norm mode.
        alpha: negative slope of the leaky-ReLU activations.

    Returns:
        Images in (-1, 1), shape (batch, 28, 28, output_dim).
    """
    # reuse=(not is_train): the inference pass shares the trained variables.
    with tf.variable_scope('generator', reuse=(not is_train)):
        # 100-dim noise -> 4 * 4 * 512 via a fully-connected layer.
        layer1 = tf.layers.dense(noise_img, 4 * 4 * 512)
        layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
        # Batch normalization (training flag selects batch vs moving stats).
        layer1 = tf.layers.batch_normalization(layer1, training=is_train)
        # Leaky ReLU, a variant of ReLU: max(alpha * x, x).
        layer1 = tf.maximum(alpha * layer1, layer1)
        # dropout
        layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
        # 4 * 4 * 512 -> 7 * 7 * 256
        layer2 = tf.layers.conv2d_transpose(layer1, 256, 4, strides=1, padding='valid')  # transposed conv, kernel size 4
        layer2 = tf.layers.batch_normalization(layer2, training=is_train)
        layer2 = tf.maximum(alpha * layer2, layer2)
        layer2 = tf.nn.dropout(layer2, keep_prob=0.8)
        # 7 * 7 * 256 -> 14 * 14 * 128
        layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same')
        layer3 = tf.layers.batch_normalization(layer3, training=is_train)
        layer3 = tf.maximum(alpha * layer3, layer3)
        layer3 = tf.nn.dropout(layer3, keep_prob=0.8)
        # 14 * 14 * 128 -> 28 * 28 * 1
        logits = tf.layers.conv2d_transpose(layer3, output_dim, 3, strides=2, padding='same')
        # MNIST pixels are in [0, 1] but tanh outputs (-1, 1), so the
        # training loop must rescale the real images to match.
        outputs = tf.tanh(logits)
        return outputs
# 判别器
def get_discriminator(inputs_img, reuse=False, alpha=0.01):
    """Score images as real vs. generated.

    Args:
        inputs_img: image batch, shape (batch, 28, 28, 1).
        reuse: reuse the discriminator variables (True on the second call).
        alpha: negative slope of the leaky-ReLU activations.

    Returns:
        (logits, outputs): raw scores and their sigmoid probabilities.
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        # 28 * 28 * 1 -> 14 * 14 * 128; no batch norm on the first layer.
        conv1 = tf.layers.conv2d(inputs_img, 128, 3, strides=2, padding='same')
        conv1 = tf.maximum(alpha * conv1, conv1)  # leaky ReLU
        conv1 = tf.nn.dropout(conv1, keep_prob=0.8)
        # 14 * 14 * 128 -> 7 * 7 * 256
        conv2 = tf.layers.conv2d(conv1, 256, 3, strides=2, padding='same')
        conv2 = tf.layers.batch_normalization(conv2, training=True)
        conv2 = tf.maximum(alpha * conv2, conv2)
        conv2 = tf.nn.dropout(conv2, keep_prob=0.8)
        # 7 * 7 * 256 -> 4 * 4 * 512
        conv3 = tf.layers.conv2d(conv2, 512, 3, strides=2, padding='same')
        conv3 = tf.layers.batch_normalization(conv3, training=True)
        conv3 = tf.maximum(alpha * conv3, conv3)
        conv3 = tf.nn.dropout(conv3, keep_prob=0.8)
        # Flatten and project to one score per image.
        flat = tf.reshape(conv3, (-1, 4 * 4 * 512))
        logits = tf.layers.dense(flat, 1)   # fully-connected score
        outputs = tf.sigmoid(logits)        # real/fake probability
        return logits, outputs
# 目标函数
def get_loss(inputs_real, inputs_noise, image_depth, smooth=0.1):
    """Build the generator and discriminator losses.

    Uses one-sided label smoothing: "real" targets are (1 - smooth)
    instead of 1 so the discriminator never becomes fully confident.
    """
    fake_images = get_generator(inputs_noise, image_depth, is_train=True)
    real_logits, real_probs = get_discriminator(inputs_real)
    fake_logits, fake_probs = get_discriminator(fake_images, reuse=True)

    def _mean_bce(logits, labels):
        # Mean sigmoid cross-entropy between logits and fixed labels.
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

    # Generator wants fakes scored as (smoothed) real.
    g_loss = _mean_bce(fake_logits, tf.ones_like(fake_probs) * (1 - smooth))
    # Discriminator wants reals near 1 (smoothed) and fakes near 0.
    d_loss_real = _mean_bce(real_logits, tf.ones_like(real_probs) * (1 - smooth))
    d_loss_fake = _mean_bce(fake_logits, tf.zeros_like(fake_probs))
    d_loss = d_loss_real + d_loss_fake
    return g_loss, d_loss
下述代码使用了tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)),至于为什么这样用,请参考博文《tensorflow中的batch_norm以及tf.control_dependencies和tf.GraphKeys.UPDATE_OPS的探究》
# 优化器
# Optimizers
def get_optimizer(g_loss, d_loss, beta1=0.4, learning_rate=0.001):
    """Create the generator and discriminator Adam training ops.

    Args:
        g_loss, d_loss: scalar loss tensors from get_loss.
        beta1: Adam first-moment decay rate.
        learning_rate: Adam learning rate.

    Returns:
        (g_opt, d_opt) training ops, each restricted to its own variables.

    Bug fix: `beta1` was accepted but never forwarded to AdamOptimizer,
    so Adam silently ran with its default beta1=0.9; it is now passed on.
    """
    train_vars = tf.trainable_variables()
    # Each network only updates its own variable-scope's variables.
    g_vars = [var for var in train_vars if var.name.startswith('generator')]
    d_vars = [var for var in train_vars if var.name.startswith('discriminator')]
    # Run the batch-norm moving-average update ops before each training step
    # (see tf.GraphKeys.UPDATE_OPS; required when using tf.layers.batch_normalization).
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        g_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)
        d_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
    return g_opt, d_opt
def plot_images(samples):
    """Display the generated images in a single row.

    Bug fixes:
    - ncols was hard-coded to 25; it now follows len(samples), so the
      function works for any preview size (figsize scales accordingly,
      keeping the original 2 inches per image).
    - The figure is now closed after display; previously figures
      accumulated across training, triggering matplotlib's
      "More than 20 figures have been opened" warning and leaking memory.
    """
    n = len(samples)
    fig, axes = plt.subplots(nrows=1, ncols=n, sharex=True, sharey=True, figsize=(2 * n, 2))
    for img, ax in zip(samples, axes):
        # Each sample is a flat/2-D array reshaped to a 28x28 grayscale image.
        ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    fig.tight_layout(pad=0)  # compact layout, no padding
    plt.show()
    plt.close(fig)  # release the figure so it is not retained by pyplot
def show_generator_output(sess, n_images, inputs_noise, output_dim):
    """Sample `n_images` images from the current generator.

    Args:
        sess: active tf.Session with initialized variables.
        n_images: number of preview images to generate.
        inputs_noise: the noise placeholder (its last dim is the noise size).
        output_dim: channel count of generated images (1 for MNIST).

    Returns:
        numpy array of shape (n_images, 28, 28) — channel axis squeezed.

    Fix: removed the unused local `cmap`.
    NOTE(review): each call builds a fresh (variable-reusing) generator
    subgraph, so the default graph grows with every preview — acceptable
    for occasional logging, but worth hoisting for long runs.
    """
    noise_dim = inputs_noise.get_shape().as_list()[-1]
    # Fresh uniform noise for the preview batch.
    examples_noise = np.random.uniform(-1, 1, size=[n_images, noise_dim])
    # is_train=False: reuse trained variables, inference-mode batch norm.
    samples = sess.run(get_generator(inputs_noise, output_dim, False),
                       feed_dict={inputs_noise: examples_noise})
    # Drop the trailing channel axis: (n, 28, 28, 1) -> (n, 28, 28).
    return np.squeeze(samples, -1)
# Training setup
# Hyperparameters
batch_size = 64        # images per training batch
noise_size = 100       # dimensionality of the generator's input noise
epochs = 5             # passes over the training set
n_samples = 25         # preview images generated at each log step
learning_rate = 0.001  # Adam learning rate
beta1 = 0.4            # Adam first-moment decay, passed to get_optimizer
def train(noise_size, data_shape, batch_size, n_samples):
    """Train the DCGAN on MNIST, logging losses and preview images.

    Args:
        noise_size: dimensionality of the generator noise vector.
        data_shape: [-1, height, width, channels] of the image batches.
        batch_size: images per training step.
        n_samples: preview images generated at each log step.

    Reads module-level globals: epochs, beta1, learning_rate, mnist.
    """
    # Collected (d_loss, g_loss) pairs at each log step.
    losses = []
    steps = 0
    inputs_real, inputs_noise = get_inputs(noise_size, data_shape[1], data_shape[2], data_shape[3])
    g_loss, d_loss = get_loss(inputs_real, inputs_noise, data_shape[-1])
    g_train_opt, d_train_opt = get_optimizer(g_loss, d_loss, beta1, learning_rate)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Iterate over epochs.
        for e in range(epochs):
            for batch_i in range(mnist.train.num_examples // batch_size):
                steps += 1
                batch = mnist.train.next_batch(batch_size)  # fetch a batch
                batch_images = batch[0].reshape((batch_size, data_shape[1], data_shape[2], data_shape[3]))
                # Rescale pixels from [0, 1] to [-1, 1] to match the
                # generator's tanh output range.
                batch_images = batch_images * 2 - 1
                # Uniform noise for the generator.
                batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
                # One generator step, then one discriminator step.
                _ = sess.run(g_train_opt, feed_dict={inputs_real: batch_images, inputs_noise: batch_noise})
                _ = sess.run(d_train_opt, feed_dict={inputs_real: batch_images, inputs_noise: batch_noise})
                if steps % 101 == 0:
                    # Evaluate both losses on the current batch and log them.
                    train_loss_d = d_loss.eval({inputs_real: batch_images, inputs_noise: batch_noise})
                    train_loss_g = g_loss.eval({inputs_real: batch_images, inputs_noise: batch_noise})
                    losses.append((train_loss_d, train_loss_g))
                    # Show a row of freshly generated preview images.
                    samples = show_generator_output(sess, n_samples, inputs_noise, data_shape[-1])
                    plot_images(samples)
                    print('Epoch {}/{}...'.format(e + 1, epochs), '判别器损失:{:.4f}...'.format(train_loss_d), '生成器损失:{:.4f}...'.format(train_loss_g))
# Build everything in a fresh default graph so re-running the cell
# does not collide with variables from a previous run.
with tf.Graph().as_default():
    train(noise_size, [-1, 28, 28, 1], batch_size, n_samples)
结果:
Epoch 1/5... 判别器损失:0.4123... 生成器损失:4.8521...
Epoch 1/5... 判别器损失:0.3654... 生成器损失:5.5562...
Epoch 1/5... 判别器损失:0.3656... 生成器损失:5.5423...
Epoch 1/5... 判别器损失:0.3578... 生成器损失:4.9352...
Epoch 1/5... 判别器损失:0.3774... 生成器损失:5.4627...
Epoch 1/5... 判别器损失:0.3695... 生成器损失:4.7884...
Epoch 1/5... 判别器损失:0.3949... 生成器损失:4.1014...
Epoch 1/5... 判别器损失:0.3748... 生成器损失:6.5352...
Epoch 2/5... 判别器损失:0.3742... 生成器损失:5.4267...
Epoch 2/5... 判别器损失:0.3784... 生成器损失:4.3531...
Epoch 2/5... 判别器损失:0.4851... 生成器损失:7.6329...
Epoch 2/5... 判别器损失:0.3608... 生成器损失:5.1529...
Epoch 2/5... 判别器损失:0.3732... 生成器损失:4.9315...
Epoch 2/5... 判别器损失:0.4854... 生成器损失:4.2088...
Epoch 2/5... 判别器损失:0.4971... 生成器损失:3.9229...
Epoch 2/5... 判别器损失:0.4148... 生成器损失:4.8006...
Epoch 2/5... 判别器损失:0.6493... 生成器损失:2.5460...
Epoch 3/5... 判别器损失:0.4466... 生成器损失:3.5046...
Epoch 3/5... 判别器损失:0.4009... 生成器损失:4.1398...
Epoch 3/5... 判别器损失:0.3964... 生成器损失:5.2334...
c:\python36\lib\site-packages\matplotlib\pyplot.py:514: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
max_open_warning, RuntimeWarning)
Epoch 3/5... 判别器损失:0.4330... 生成器损失:3.5293...
Epoch 3/5... 判别器损失:0.4965... 生成器损失:3.1398...
Epoch 3/5... 判别器损失:0.5042... 生成器损失:3.7247...
Epoch 3/5... 判别器损失:0.4048... 生成器损失:3.6227...
Epoch 3/5... 判别器损失:0.5662... 生成器损失:4.0197...
Epoch 4/5... 判别器损失:0.3839... 生成器损失:4.9505...
Epoch 4/5... 判别器损失:0.5061... 生成器损失:2.5535...
Epoch 4/5... 判别器损失:0.4226... 生成器损失:3.0717...
Epoch 4/5... 判别器损失:0.4145... 生成器损失:5.4944...
Epoch 4/5... 判别器损失:0.4470... 生成器损失:3.5433...
Epoch 4/5... 判别器损失:0.3766... 生成器损失:4.9400...
Epoch 4/5... 判别器损失:0.4915... 生成器损失:3.3704...
Epoch 4/5... 判别器损失:0.4568... 生成器损失:5.1727...
Epoch 4/5... 判别器损失:0.4236... 生成器损失:3.4684...
Epoch 5/5... 判别器损失:0.5798... 生成器损失:2.2550...
Epoch 5/5... 判别器损失:0.4153... 生成器损失:3.3579...
Epoch 5/5... 判别器损失:0.4763... 生成器损失:4.3362...
Epoch 5/5... 判别器损失:0.4370... 生成器损失:2.7172...
Epoch 5/5... 判别器损失:0.4608... 生成器损失:3.1666...
Epoch 5/5... 判别器损失:0.4462... 生成器损失:2.9232...
Epoch 5/5... 判别器损失:0.4985... 生成器损失:2.3787...
Epoch 5/5... 判别器损失:0.4437... 生成器损失:3.7944...
结论
GAN实现的方法很多,博文中的两篇帖子只是其中两个方法,读者可以根据具体场景进行选择。
参考
- https://blog.csdn.net/huitailangyz/article/details/85015611 --tf.layers.batch_normalization()实现原理介绍,使用with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):的原因