【TensorFlow】对抗生成网络(GAN)- 根据MNIST数据集模拟生成数字手写体实战

简介

简单的说,对抗生成网络就是真和假之间的对抗。造“假”水平不断提高,以此瞒过“真”的眼睛;辩“真”能力不断提高,以此识别出“假”。两股力量不断博弈,最后达到以假乱真的目的。

  • 生成器G(造假) :生成的结果越真越好,可以以假乱真。达到瞒天过海,骗过判别器的目的
  • 判别器D(打假):具有火眼金睛,分辩真假的能力越强越好。分辨出生成和真实的
  • 损失函数:一方面要让判别器分辨能力更强,另一方面要让生成器更真

网络架构

输入层:待生成图像(噪音)和真实数据

生成网络:将噪音图像进行生成(造假)

判别网络:

  1. 判断真实图像输出结果
  2. 判断生成图像输出结果

目标函数:

  1. 对于生成网络要使得生成结果通过判别网络为真
  2. 对于判别网络要使得输入为真实图像时判别为真,输入为生成图像时判别为假

判别模型:共享一组权重参数

程序目的:优化目标函数,把损失值为尽量优化至最小

实现

import tensorflow as tf
import numpy as np
import pickle # 把生成的结果保存至本地
import matplotlib.pyplot as plt

%matplotlib inline
# 导入数据
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('/data')
# 真实数据和噪音数据
def get_inputs(real_size, noise_size):  # 原始论文中,noise_size大小为100
    real_img = tf.placeholder(tf.float32, [None, real_size])
    noise_img = tf.placeholder(tf.float32, [None, noise_size])
    
    return real_img, noise_img

生成器(G):

  • noise_img: 产生的噪音输入
  • n_units:隐层单元个数
  • out_dim:输出的大小(28 28 1)

tf.layers.dense()完成了全连接层的操作,即等价于操作:Wx+b --> 激活函数。函数的具体功能参考详细博文

def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    with tf.variable_scope('generator', reuse=reuse):  # 参数更新,要不要重新利用参数
        # hidden layer
        hidden1 = tf.layers.dense(noise_img, n_units)  # dense:全连接层。输出结果的最后一维度就等于神经元的个数,即units的数值
        # leaku Relu : RELU激活函数的变形体
        hidden1 = tf.maximum(alpha * hidden1, hidden1) # 返回两者间的最大值
        # dropout
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
        
        # logits & outputs
        logits = tf.layers.dense(hidden1, out_dim)  # 全连接层
        outputs = tf.tanh(logits)  # 把值压缩至[-1, 1]区间
        
        return logits, outputs  # 返回得分值和输出值

判别器(D):

  • img:输入
  • n_units:隐层单元数量
  • reuse:由于要使用两次(两种不同的输入,真实值输入和生成器的输入)
def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    with tf.variable_scope('discriminator', reuse=reuse):  # 参数更新,要不要重新利用参数
        # hidden layer
        hidden1 = tf.layers.dense(img, n_units)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        
        # logits & outputs
        logits = tf.layers.dense(hidden1, 1)  # 全连接层
        outputs = tf.sigmoid(logits)  # 把值压缩至[-1, 1]区间
        
        return logits, outputs  # 返回得分值和输出值

网络参数定义:

  • img_size:输入大小
  • noise_size:噪声图像大小
  • g_units:生成器隐层参数
  • d_units:判别器隐层参数
  • learning_rate:学习率
img_size = mnist.train.images[0].shape[0]

noise_size = 100

g_units = 128

d_units = 128

learning_rate = 0.001

alpha = 0.01

构建网络:

# 构建网络
tf.reset_default_graph()

real_img, noise_img = get_inputs(img_size, noise_size)

# generator
g_logits, g_outputs = get_generator(noise_img, g_units, img_size)

# discriminator
d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)  # 判别真实的图像
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)  # 判别生成的图像。reuse=True使用同样一组参数

目标函数:

# discriminator的loss
# 识别真实图片
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, labels=tf.ones_like(d_logits_real)))  # labels都是1,判别真实图像时结果为1.自己设置label

# 识别生成的图片
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))  # 生成图片判别为假,即0

# 总体loss
d_loss = tf.add(d_loss_real, d_loss_fake)

#generator的loss
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)))

优化器:

# 优化器
train_vars = tf.trainable_variables()  # 查看可训练的变量

# generator
g_vars = [var for var in train_vars if var.name.startswith('generator')] # 把域下的变量拿过来
print(g_vars)
# discriminator
d_vars = [var for var in train_vars if var.name.startswith('discriminator')]
print(d_vars)

# optimizer
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
[<tf.Variable 'generator/dense/kernel:0' shape=(100, 128) dtype=float32_ref>, <tf.Variable 'generator/dense/bias:0' shape=(128,) dtype=float32_ref>, <tf.Variable 'generator/dense_1/kernel:0' shape=(128, 784) dtype=float32_ref>, <tf.Variable 'generator/dense_1/bias:0' shape=(784,) dtype=float32_ref>]
[<tf.Variable 'discriminator/dense/kernel:0' shape=(784, 128) dtype=float32_ref>, <tf.Variable 'discriminator/dense/bias:0' shape=(128,) dtype=float32_ref>, <tf.Variable 'discriminator/dense_1/kernel:0' shape=(128, 1) dtype=float32_ref>, <tf.Variable 'discriminator/dense_1/bias:0' shape=(1,) dtype=float32_ref>]

训练:

# 训练
# batch_size
batch_size = 64
# 训练迭代轮数
epochs = 300
# 抽取样本数
n_sample = 25

# 存储测试样例
samples = []
# 存储loss
losses = []
# 保存生成器变量
saver = tf.train.Saver(var_list=g_vars)
# 开始训练
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for batch_i in range(mnist.train.num_examples // batch_size):
            batch = mnist.train.next_batch(batch_size)  # 取数据
            
            batch_images = batch[0].reshape((batch_size, 784))
            # 对图像像素进行scale,这是因为tanh输出的结果介于(-1, 1),real和fake图片共享discriminator的参数
            batch_images = batch_images * 2 - 1 # 转换至[0, 1]区间
            
            # generator的输入噪声
            batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))  # 在-1至1区间随机产生一个size形状的batch_noise
            
            # run optimizers
            _ = sess.run(d_train_opt, feed_dict={real_img: batch_images, noise_img: batch_noise})
            _ = sess.run(g_train_opt, feed_dict={noise_img: batch_noise})
            
        # 每一轮结束计算loss
        train_loss_d = sess.run(d_loss, feed_dict={real_img: batch_images, noise_img: batch_noise})
        
        # real img loss
        train_loss_d_real = sess.run(d_loss_real, feed_dict={real_img: batch_images, noise_img: batch_noise})
        
        # fake img loss
        train_loss_d_fake = sess.run(d_loss_fake, feed_dict={real_img: batch_images, noise_img: batch_noise})
        
        # generator loss
        train_loss_g = sess.run(g_loss, feed_dict={noise_img: batch_noise})
        
        print('Epoch {}/{}...'.format(e + 1, epochs), '判别器损失:{:.4f}(判别真实的:{:.4f} + 判别生成的:{:.4f})...'.format(train_loss_d, train_loss_d_real, train_loss_d_fake), '生成器损失:{:.4f}'.format(train_loss_g))
        
        losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g))
        
        # 保存样本
        sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))
        gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True), feed_dict={noise_img: sample_noise})
        samples.append(gen_samples)
        
        saver.save(sess, './checkpoints/generator.ckpt')
        
# 保存到本地
with open('train_samples.pkl', 'wb') as f:
    pickle.dump(samples, f)  # 序列化对象,将对象samples保存到文件f中去

部分执行结果:

Epoch 1/300... 判别器损失:0.0758(判别真实的:0.0131 + 判别生成的:0.0626)... 生成器损失:3.6041
Epoch 2/300... 判别器损失:4.4196(判别真实的:1.2850 + 判别生成的:3.1346)... 生成器损失:0.6899
Epoch 3/300... 判别器损失:0.6978(判别真实的:0.4161 + 判别生成的:0.2818)... 生成器损失:3.3702
Epoch 4/300... 判别器损失:0.8266(判别真实的:0.5486 + 判别生成的:0.2780)... 生成器损失:3.7675
Epoch 5/300... 判别器损失:1.3491(判别真实的:0.8964 + 判别生成的:0.4528)... 生成器损失:1.7467
Epoch 6/300... 判别器损失:0.9448(判别真实的:0.4921 + 判别生成的:0.4527)... 生成器损失:1.6264
Epoch 7/300... 判别器损失:1.1907(判别真实的:0.6027 + 判别生成的:0.5880)... 生成器损失:1.9225
Epoch 8/300... 判别器损失:1.6996(判别真实的:0.9158 + 判别生成的:0.7838)... 生成器损失:0.8550
Epoch 9/300... 判别器损失:1.8214(判别真实的:0.6140 + 判别生成的:1.2074)... 生成器损失:1.0086
Epoch 10/300... 判别器损失:1.5922(判别真实的:0.9327 + 判别生成的:0.6595)... 生成器损失:1.7799
Epoch 11/300... 判别器损失:0.7522(判别真实的:0.5029 + 判别生成的:0.2494)... 生成器损失:2.1485
Epoch 12/300... 判别器损失:0.7268(判别真实的:0.3073 + 判别生成的:0.4196)... 生成器损失:2.0547
Epoch 13/300... 判别器损失:0.9767(判别真实的:0.4555 + 判别生成的:0.5212)... 生成器损失:1.9839
Epoch 14/300... 判别器损失:0.6412(判别真实的:0.1841 + 判别生成的:0.4571)... 生成器损失:2.0994
Epoch 15/300... 判别器损失:1.0133(判别真实的:0.6619 + 判别生成的:0.3514)... 生成器损失:1.8109
Epoch 16/300... 判别器损失:0.7774(判别真实的:0.3759 + 判别生成的:0.4015)... 生成器损失:1.9851
Epoch 17/300... 判别器损失:0.8500(判别真实的:0.5225 + 判别生成的:0.3275)... 生成器损失:1.9666
Epoch 18/300... 判别器损失:1.0968(判别真实的:0.6635 + 判别生成的:0.4334)... 生成器损失:2.0298
...
...
...
Epoch 288/300... 判别器损失:0.8609(判别真实的:0.4858 + 判别生成的:0.3751)... 生成器损失:1.8564
Epoch 289/300... 判别器损失:0.7995(判别真实的:0.3920 + 判别生成的:0.4075)... 生成器损失:1.8356
Epoch 290/300... 判别器损失:0.7399(判别真实的:0.4744 + 判别生成的:0.2655)... 生成器损失:2.1919
Epoch 291/300... 判别器损失:0.7337(判别真实的:0.3043 + 判别生成的:0.4294)... 生成器损失:1.5659
Epoch 292/300... 判别器损失:0.8654(判别真实的:0.4137 + 判别生成的:0.4517)... 生成器损失:1.6585
Epoch 293/300... 判别器损失:1.0221(判别真实的:0.5846 + 判别生成的:0.4374)... 生成器损失:1.9332
Epoch 294/300... 判别器损失:0.8926(判别真实的:0.3748 + 判别生成的:0.5178)... 生成器损失:1.5783
Epoch 295/300... 判别器损失:0.8854(判别真实的:0.4075 + 判别生成的:0.4779)... 生成器损失:1.6867
Epoch 296/300... 判别器损失:1.0185(判别真实的:0.6237 + 判别生成的:0.3948)... 生成器损失:1.9525
Epoch 297/300... 判别器损失:0.9022(判别真实的:0.4550 + 判别生成的:0.4472)... 生成器损失:1.7455
Epoch 298/300... 判别器损失:1.1332(判别真实的:0.6172 + 判别生成的:0.5160)... 生成器损失:1.7095
Epoch 299/300... 判别器损失:0.9021(判别真实的:0.4032 + 判别生成的:0.4989)... 生成器损失:1.4841
Epoch 300/300... 判别器损失:0.8167(判别真实的:0.3760 + 判别生成的:0.4408)... 生成器损失:1.7197

loss迭代曲线:

# loss迭代曲线
fig, ax = plt.subplots(figsize=(20,7))  # 生成20 * 7 大小的图
losses = np.array(losses)
plt.rcParams['font.sans-serif'] = ['SimHei'] #指定默认字体 SimHei为黑体,解决中文乱码
plt.rcParams['axes.unicode_minus'] = False #用来正常显示负号
plt.plot(losses.T[0], label='判别器总损失')
plt.plot(losses.T[1], label='判别器真实损失')
plt.plot(losses.T[2], label='判别器生成损失')
plt.plot(losses.T[3], label='生成器损失')
plt.title('对抗生成网路')
ax.set_xlabel('epoch')
plt.legend()

从下图可知,当epoch=150时,网络已趋于稳定。

生成结果:

# 生成结果
# Load samples from generator taken while training
with open('train_samples.pkl', 'rb') as f:
    samples = pickle.load(f)  # 反序列化对象,将文件中的数据解析为一个python对象

nrows,ncols:子图的行列数。
sharex, sharey:

  • 设置为 True 或者 ‘all’ 时,所有子图共享 x 轴或者 y 轴,
  • 设置为 False or ‘none’ 时,所有子图的 x,y 轴均为独立,
  • 设置为 ‘row’ 时,每一行的子图会共享 x 或者 y 轴,
  • 设置为 ‘col’ 时,每一列的子图会共享 x 或者 y 轴。

返回值:

  • fig: matplotlib.figure.Figure 对象
  • axes:子图对象( matplotlib.axes.Axes)或者是他的数组
# samples是保存的结果 epoch是第几次迭代
def view_samples(epoch, samples):
    
    fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5, sharey=True, sharex=True)
    # 在用plt.subplots画多个子图中,axes.flatten()将axes由n*m的Axes组展平成1*nm的Axes组
    # zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表
    for ax, img in zip(axes.flatten(), samples[epoch][1]):  # 这里samples[epoch][1]代表生成的图像结果,而[0]代表对应的logits
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        
    return fig, axes
_ = view_samples(-1, samples)  # 显示最终的生成结果

生成器生成的图片(造假):

展示图片生成过程:

# 显示整个生成过程图片
# 指定要查看的轮次
epoch_idx = [10, 30, 60, 90, 120, 150, 180, 210, 240, 290]
show_imgs = []
for i in epoch_idx:
    show_imgs.append(samples[i][1])
# 指定图片形状
rows, cols = 10, 25
fig, axes = plt.subplots(figsize=(30, 12), nrows=rows, ncols=cols, sharex=True, sharey=True)

idx = range(0, epochs, int(epochs / rows))

for sample, ax_row in zip(show_imgs, axes):
    for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
        ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

 生成新的图片:

# 生成新的图片
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
    sample_noise = np.random.uniform(-1, 1, size=(25, noise_size))
    gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True), feed_dict={noise_img: sample_noise})
    
INFO:tensorflow:Restoring parameters from checkpoints\generator.ckpt
_ = view_samples(0, [gen_samples])

生成的新图片:

结论

本文用全连接的方法简单的构造了一个GAN网络,全连接方式的网络较简单,下次尝试使用卷积的方式构建DCGAN网络。

参考

  • 3
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值