# 先导入必要的库
import os
import cv2
import tensorflow as tf
import numpy as np
# 把结果保存到本地的一个库
import pickle
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# 读取mnist数据集
mnist = input_data.read_data_sets("MNIST_DATA")
# 图片的大小为28x28, 即784
img_size = 784
# 输入的噪声,也可以设置成别的值
noise_size = 100
# 生成网络的隐藏层神经元个数
g_units = 128
# 判别网络的隐藏层神经元个数
d_units = 128
# 学习率
learning_rate = 0.001
# 每个batch的大小
batch_size = 64
# 迭代的轮数, 这里每个epoch会遍历一次训练的数据集
epochs = 300
# 对生成的图片采样保存
n_sample = 25
samples = []
# 获取输入的函数
def get_input(real_size, noise_size):
"""
:param real_size: 真实图片的大小
:param noise_size: 噪声的长度
:return: 返回两个占位符,其实就是判别网络和生成网络的输入
"""
real_img = tf.placeholder(tf.float32, [None, real_size])
noise_img = tf.placeholder(tf.float32, [None, noise_size])
return real_img, noise_img
# 生成器共有两层结构,noise_img---->n_units----->out_dim
def get_generator(noise, n_units, out_dim=img_size, reuse=False, alpha=0.01):
"""
实现生成网络
:param noise: 生成网络的输入
:param n_units: 生成网络的隐藏层神经元个数
:param out_dim: 生成网络的输出 [None, 784]
:param reuse: 是否重复使用网络的各种参数
:param alpha: LeakRelu的参数
:return: 生成模型未经激活的输出,和tanh激活之后的输出
"""
# 创建一个命名空间, 名称为generator
with tf.variable_scope("generator", reuse=reuse):
# 第一层隐藏层
hidden1 = tf.layers.dense(noise, n_units)
# 激活函数和dropout
hidden1 = tf.maximum(alpha * hidden1, hidden1)
hidden1 = tf.layers.dropout(hidden1, rate=0.2)
# 网络未经过激活函数之前输出的结果
logits = tf.layers.dense(hidden1, out_dim)
out_puts = tf.tanh(logits)
return logits, out_puts
# 判别器的结构: img---->n_units---1
def get_discriminator(img, n_units, reuse=False, alpha=0.01):
"""
:param img: 输入图像的大小
:param n_units: 判别网络隐藏层神经元的个数
:param reuse: 是否重用模型的参数
:param alpha: LeakRelu的参数
:return: 判别模型未经激活的输出,和tanh激活之后的输出
"""
# 创建一个命名空间, 名称为discriminator
with tf.variable_scope("discriminator", reuse=reuse):
# 第一层结构
hidden1 = tf.layers.dense(img, n_units)
# 激活函数
hidden1 = tf.maximum(alpha * hidden1, hidden1)
# 输出层
logits = tf.layers.dense(hidden1, 1)
# 使用sigmoid激活函数
outputs = tf.sigmoid(logits)
return logits, outputs
# tf.reset_default_graph函数用于清除默认图形堆栈并重置全局默认图形
tf.reset_default_graph()
# 接受两个placeholder
real_img, noise_img = get_input(img_size, noise_size)
# 调用生成网络
g_logits, g_outputs = get_generator(noise_img, g_units)
# 判别网络对真实图片的判别结果
d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
# 判别网络对生成图片的判别结果, resue表示使用和上面相同的结构和参数
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)
# 计算损失,判别网络对真实图片的损失,tf.ones_like(x), 会生成形状如x, 数值为1的向量, tf.zeros_like(x) 同理
# 从判别器的角度我们希望判别网络能把真实的图片预测为1,把生成的图片预测为零
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
labels=tf.ones_like(d_logits_real)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
labels=tf.zeros_like(d_logits_fake)))
# 将两部分的损失合起来就是判别网络的损失值了
d_loss = tf.add(d_loss_real, d_loss_fake)
# 从生成器的角度来看,我们又希望生成器能生成接近真实的图片,也就是让判别器尽可能的把生成的图片也预测为1,这就是对抗的思想
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
labels=tf.ones_like(d_logits_fake)))
# 优化过程, 通过命名空间的特性取出生成器和判别器中的所有的参数,以便于优化其损失
train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
# 优化操作,使用Adam函数进行优化,注意后面的变量列表要与正在优化的损失对应
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
# 用于保存生成的图片,便于直观的看到生成模型的效果
if not os.path.exists('gen_pictures/'):
os.makedirs('gen_pictures/')
# 保存模型,只需要保存生成模型即可
saver = tf.train.Saver(var_list=g_vars)
# 打开一个会话, 开始训练过程
with tf.Session() as sess:
# 初始化所有的变量
sess.run(tf.global_variables_initializer())
for epoch in range(epochs):
# 每个epoch会把训练样本过一遍
for batch_i in range(mnist.train.num_examples // batch_size):
# 从真实样本数据中取出一个batch, 表示真实的图片
batch = mnist.train.next_batch(batch_size)
batch_image = batch[0].reshape((batch_size, 784))
# 同样的构造一个batch_size 的噪声
batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
#开始进行迭代训练
sess.run(d_train_opt, feed_dict={real_img: batch_image, noise_img: batch_noise})
sess.run(g_train_opt, feed_dict={noise_img: batch_noise})
# 打印每个epoch的生成网络的损失,和判别网络的损失
train_loss_d = sess.run(d_loss,
feed_dict={real_img: batch_image, noise_img: batch_noise})
train_loss_g = sess.run(g_loss, feed_dict={noise_img: batch_noise})
print("Iterations " + str(epoch) + ", the discrimator loss is: %.4f, generator loss is: %.4f" %(train_loss_d, train_loss_g))
# 使用更新后的参数生成一定数量的样本
sample_noise = np.random.uniform(-1, 1, size=(n_sample, noise_size))
_, gen_samples = sess.run(get_generator(noise_img, g_units, img_size, reuse=True),
feed_dict={noise_img: sample_noise})
# 从生成的图片中随机的选取一张保存下来
single_picture = gen_samples[np.random.randint(0, n_sample)]
# 生成图片的激活函数是tanh(-1, 1) --->(0, 2) ---->(0, 255)
single_picture = (np.reshape(single_picture, (28, 28)) + 1) * 177.5
# 保存图片
cv2.imwrite("gen_pictures/A{}.jpg".format(str(epoch)), single_picture)
samples.append(gen_samples)
# 保存模型
saver.save(sess, "./checkpoints/generator.ckpt")
# 将生成的图片结果写入文件
with open("train_samples.pkl", "wb") as f:
pickle.dump(samples, f)
优化版本对生成对抗网络生成手写数字集(附代码详解)
最新推荐文章于 2022-12-30 20:24:06 发布