优化版本对生成对抗网络生成手写数字集(附代码详解)

# 先导入必要的库
import os
import cv2
import tensorflow as tf
import numpy as np
# 把结果保存到本地的一个库
import pickle
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# 读取mnist数据集
mnist = input_data.read_data_sets("MNIST_DATA")
# 图片的大小为28x28, 即784
img_size = 784

# 输入的噪声,也可以设置成别的值
noise_size = 100

# 生成网络的隐藏层神经元个数
g_units = 128
# 判别网络的隐藏层神经元个数
d_units = 128
# 学习率
learning_rate = 0.001
# 每个batch的大小
batch_size = 64
# 迭代的轮数, 这里每个epoch会遍历一次训练的数据集
epochs = 300
# 对生成的图片采样保存
n_sample = 25
samples = []


# 获取输入的函数
def get_input(real_size, noise_size):
    """
    :param real_size: 真实图片的大小
    :param noise_size: 噪声的长度
    :return: 返回两个占位符,其实就是判别网络和生成网络的输入
    """
    real_img = tf.placeholder(tf.float32, [None, real_size])
    noise_img = tf.placeholder(tf.float32, [None, noise_size])

    return real_img, noise_img

# 生成器共有两层结构,noise_img---->n_units----->out_dim
def get_generator(noise, n_units, out_dim=img_size, reuse=False, alpha=0.01):
    """
    实现生成网络
    :param noise: 生成网络的输入
    :param n_units: 生成网络的隐藏层神经元个数
    :param out_dim: 生成网络的输出 [None, 784]
    :param reuse: 是否重复使用网络的各种参数
    :param alpha: LeakRelu的参数
    :return: 生成模型未经激活的输出,和tanh激活之后的输出
    """
    # 创建一个命名空间, 名称为generator
    with tf.variable_scope("generator", reuse=reuse):
        # 第一层隐藏层
        hidden1 = tf.layers.dense(noise, n_units)
        # 激活函数和dropout
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)
        # 网络未经过激活函数之前输出的结果
        logits = tf.layers.dense(hidden1, out_dim)
        out_puts = tf.tanh(logits)
        return logits, out_puts
# 判别器的结构: img---->n_units---1
def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    """
    :param img: 输入图像的大小
    :param n_units: 判别网络隐藏层神经元的个数
    :param reuse: 是否重用模型的参数
    :param alpha: LeakRelu的参数
    :return: 判别模型未经激活的输出,和tanh激活之后的输出
    """

    # 创建一个命名空间, 名称为discriminator
    with tf.variable_scope("discriminator", reuse=reuse):
        # 第一层结构
        hidden1 = tf.layers.dense(img, n_units)
        # 激活函数
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        # 输出层
        logits = tf.layers.dense(hidden1, 1)
        # 使用sigmoid激活函数
        outputs = tf.sigmoid(logits)

        return logits, outputs

# tf.reset_default_graph函数用于清除默认图形堆栈并重置全局默认图形
tf.reset_default_graph()
# 接受两个placeholder
real_img, noise_img = get_input(img_size, noise_size)
# 调用生成网络
g_logits, g_outputs = get_generator(noise_img, g_units)

# 判别网络对真实图片的判别结果
d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
# 判别网络对生成图片的判别结果, resue表示使用和上面相同的结构和参数
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)


# 计算损失,判别网络对真实图片的损失,tf.ones_like(x), 会生成形状如x, 数值为1的向量, tf.zeros_like(x) 同理
# 从判别器的角度我们希望判别网络能把真实的图片预测为1,把生成的图片预测为零
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                                     labels=tf.ones_like(d_logits_real)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                     labels=tf.zeros_like(d_logits_fake)))
# 将两部分的损失合起来就是判别网络的损失值了
d_loss = tf.add(d_loss_real, d_loss_fake)


# 从生成器的角度来看,我们又希望生成器能生成接近真实的图片,也就是让判别器尽可能的把生成的图片也预测为1,这就是对抗的思想
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                labels=tf.ones_like(d_logits_fake)))

# 优化过程, 通过命名空间的特性取出生成器和判别器中的所有的参数,以便于优化其损失
train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]

# 优化操作,使用Adam函数进行优化,注意后面的变量列表要与正在优化的损失对应
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

# 用于保存生成的图片,便于直观的看到生成模型的效果
if not os.path.exists('gen_pictures/'):
    os.makedirs('gen_pictures/')

# 保存模型,只需要保存生成模型即可
saver = tf.train.Saver(var_list=g_vars)

# 打开一个会话, 开始训练过程
with tf.Session() as sess:
    # 初始化所有的变量
    sess.run(tf.global_variables_initializer())

    for epoch in range(epochs):
        # 每个epoch会把训练样本过一遍
        for batch_i in range(mnist.train.num_examples // batch_size):
            # 从真实样本数据中取出一个batch, 表示真实的图片
            batch = mnist.train.next_batch(batch_size)
            batch_image = batch[0].reshape((batch_size, 784))

            # 同样的构造一个batch_size 的噪声
            batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))

            #开始进行迭代训练
            sess.run(d_train_opt, feed_dict={real_img: batch_image, noise_img: batch_noise})
            sess.run(g_train_opt, feed_dict={noise_img: batch_noise})

        # 打印每个epoch的生成网络的损失,和判别网络的损失
        train_loss_d = sess.run(d_loss,
                                feed_dict={real_img: batch_image, noise_img: batch_noise})
        train_loss_g = sess.run(g_loss, feed_dict={noise_img: batch_noise})

        print("Iterations " + str(epoch) + ", the discrimator loss is: %.4f, generator loss is: %.4f" %(train_loss_d, train_loss_g))



        # 使用更新后的参数生成一定数量的样本
        sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))
        _, gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True),
                               feed_dict={noise_img: sample_noise})

        # 从生成的图片中随机的选取一张保存下来
        single_picture = gen_samples[np.random.randint(0, n_sample)]
        # 生成图片的激活函数是tanh(-1, 1) --->(0, 2) ---->(0, 255)
        single_picture = (np.reshape(single_picture, (28, 28)) + 1) * 177.5
        # 保存图片
        cv2.imwrite("gen_pictures/A{}.jpg".format(str(epoch)), single_picture)

        samples.append(gen_samples)

        # 保存模型
        saver.save(sess, "./checkpoints/generator.ckpt")

# 将生成的图片结果写入文件
with open("train_samples.pkl", "wb") as f:
    pickle.dump(samples, f)
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值