Building a Fully-Connected GAN Network
"""
PROJECT:MNIST_GAN_MLP
Author:Ephemeroptera
Date:2018-4-24
QQ:605686962
Reference:' improved_wgan_training-master': <https://github.com/igul222/improved_wgan_training>
'Zardinality/WGAN-tensorflow':<https://github.com/Zardinality/WGAN-tensorflow>
'NELSONZHAO/zhihu':<https://github.com/NELSONZHAO/zhihu>
"""
import tensorflow as tf
import numpy as np
import pickle
import visualization  # local helper module for displaying image batches
import os
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from threading import Thread
from time import sleep
import time
import cv2

# load the MNIST dataset
mnist_dir = r'../mnist_dataset'
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(mnist_dir)
# generator: one hidden leaky-ReLU layer with dropout, tanh output in [-1, 1]
def Generator_MLP(latents, out_dim, reuse=False):
    units = 128
    with tf.variable_scope("generator", reuse=reuse):
        dense0 = tf.layers.dense(latents, units, activation=tf.nn.leaky_relu, name='dense0')
        dropout = tf.layers.dropout(dense0, rate=0.2, name='dropout')
        logits = tf.layers.dense(dropout, out_dim, name='dense1')
        outputs = tf.tanh(logits, name='outputs')
        return logits, outputs
# discriminator: one hidden leaky-ReLU layer, sigmoid output (probability of "real")
def Discriminator_MLP(inputs, out_dim, reuse=False):
    units = 128
    with tf.variable_scope("discriminator", reuse=reuse):
        dense0 = tf.layers.dense(inputs, units, activation=tf.nn.leaky_relu, name='dense0',
                                 kernel_initializer=tf.random_normal_initializer(0, 0.1))
        logits = tf.layers.dense(dense0, out_dim, name='dense1',
                                 kernel_initializer=tf.random_normal_initializer(0, 0.1))
        outputs = tf.sigmoid(logits, name='outputs')
        return logits, outputs
# count the total number of trainable parameters
def COUNT_VARS(var_list):
    total_para = 0
    for variable in var_list:
        shape = variable.get_shape()
        variable_para = 1
        for dim in shape:
            variable_para *= dim.value
        total_para += variable_para
    return total_para
# print every trainable variable and write the list to ./trainLog/Paras.txt
def ShowParasList(paras):
    with open('./trainLog/Paras.txt', 'w') as p:
        p.writelines(['vars_total: %d' % COUNT_VARS(paras), '\n'])
        for variable in paras:
            p.writelines([variable.name, str(variable.get_shape()), '\n'])
            print(variable.name, variable.get_shape())
# create the output directories if they do not exist
def GEN_DIR():
    if not os.path.isdir('ckpt'):
        print('DIR: ckpt NOT FOUND, BUILDING ON CURRENT PATH..')
        os.mkdir('ckpt')
    if not os.path.isdir('trainLog'):
        print('DIR: trainLog NOT FOUND, BUILDING ON CURRENT PATH..')
        os.mkdir('trainLog')
# hyperparameters
latents_dim = 128   # dimension of the latent vector z
img_dim = 28 * 28   # flattened MNIST image size
smooth = 0.1        # one-sided label smoothing factor for "real" labels
learn_rate = 0.001

# placeholders for the latent codes and the real images
latents = tf.placeholder(shape=[None, latents_dim], dtype=tf.float32, name='latents')
input_real = tf.placeholder(shape=[None, img_dim], dtype=tf.float32, name='input_real')

# build the graph: the discriminator is applied to both the real batch and the
# generated batch, sharing weights via reuse=True
_, g_outputs = Generator_MLP(latents, img_dim, reuse=False)
d_logits_real, d_outputs_real = Discriminator_MLP(input_real, 1, reuse=False)
d_logits_fake, d_outputs_fake = Discriminator_MLP(g_outputs, 1, reuse=True)
# discriminator loss: real samples labelled (1 - smooth), fakes labelled 0
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real, labels=tf.ones_like(d_logits_real) * (1 - smooth)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))
d_loss = tf.add(d_loss_real, d_loss_fake)

# generator loss: push the discriminator to label fakes as real
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.ones_like(d_logits_fake) * (1 - smooth)))
# the output directories must exist before ShowParasList writes to ./trainLog
GEN_DIR()

# split the trainable variables by scope and give each network its own optimizer
train_vars = tf.trainable_variables()
ShowParasList(train_vars)
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
d_train_opt = tf.train.AdamOptimizer(learn_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learn_rate).minimize(g_loss, var_list=g_vars)

# training settings and in-memory records
max_iters = 20000
batch_size = 64
critic_n = 5   # discriminator updates per generator update
GenLog = []    # periodic snapshots of generated samples
Losses = []    # periodic [step, d_loss, g_loss] records
saver = tf.train.Saver(var_list=g_vars)  # only the generator is checkpointed
# dump the loss history and generation log to disk (run in a background thread)
def SavingRecords():
    global Losses
    global GenLog
    with open('./trainLog/loss_variation.loss', 'wb') as l:
        losses = np.array(Losses)
        pickle.dump(losses, l)
        print('saving Losses successfully!')
    with open('./trainLog/GenLog.log', 'wb') as g:
        GenLog = np.array(GenLog)
        pickle.dump(GenLog, g)
        print('saving GenLog successfully!')
# training loop, written as a generator so FuncAnimation can consume the losses live
def training():
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        time_start = time.time()
        for steps in range(max_iters + 1):
            # rescale MNIST pixels from [0, 1] to [-1, 1] to match the tanh output
            data_batch = mnist.train.next_batch(batch_size)[0]
            data_batch = data_batch * 2 - 1
            data_batch = data_batch.astype(np.float32)
            z = np.random.normal(0, 1, size=[batch_size, latents_dim]).astype(np.float32)
            # update the discriminator critic_n times per generator update
            for n in range(critic_n):
                sess.run(d_train_opt, feed_dict={input_real: data_batch, latents: z})
            sess.run(g_train_opt, feed_dict={latents: z})
            # evaluate the current losses and show a 3x3 grid of samples
            train_loss_d = sess.run(d_loss, feed_dict={input_real: data_batch, latents: z})
            train_loss_g = sess.run(g_loss, feed_dict={latents: z})
            info = [steps, train_loss_d, train_loss_g]
            gen_samples = sess.run(g_outputs, feed_dict={latents: z})
            visualization.CV2_BATCH_SHOW((np.reshape(gen_samples[0:9], [-1, 28, 28, 1]) + 1) / 2, 1, 3, 3, delay=1)
            print('iters::%d/%d..Discriminator_loss:%.3f..Generator_loss:%.3f..'
                  % (steps, max_iters, train_loss_d, train_loss_g))
            # record losses and samples every 5 steps
            if steps % 5 == 0:
                Losses.append(info)
                GenLog.append(gen_samples)
            # checkpoint the generator every 1000 steps
            if steps % 1000 == 0 and steps > 0:
                saver.save(sess, './ckpt/generator.ckpt', global_step=steps)
            # after the final step, save the records from a background thread
            if steps == max_iters:
                sleep(3)
                thread1 = Thread(target=SavingRecords, args=())
                thread1.start()
            yield info
"""
note: in this code , we will see the runtime-variation of G,D losses
"""
# live loss curves: training() yields [step, d_loss, g_loss]; update() appends each record
iters = []
dloss = []
gloss = []
fig = plt.figure()
ax1 = fig.add_subplot(2, 1, 1, xlim=(0, max_iters), ylim=(-1, 1))
ax2 = fig.add_subplot(2, 1, 2, xlim=(0, max_iters), ylim=(-20, 20))
ax1.set_title('discriminator_loss')
ax2.set_title('generator_loss')
line1, = ax1.plot([], [], color='red', lw=1, label='discriminator')
line2, = ax2.plot([], [], color='blue', lw=1, label='generator')
fig.tight_layout()

def init():
    line1.set_data([], [])
    line2.set_data([], [])
    return line1, line2

def update(info):
    iters.append(info[0])
    dloss.append(info[1])
    gloss.append(info[2])
    line1.set_data(iters, dloss)
    line2.set_data(iters, gloss)
    return line1, line2

ani = FuncAnimation(fig, update, frames=training, init_func=init, blit=True, interval=1, repeat=False)
plt.show()
Experimental results
1. Loss curves (loaded back from ./trainLog/loss_variation.loss; see the first sketch below)
2. Generation log (replayed from ./trainLog/GenLog.log; see the second sketch below)
3. Verifying the generator (restored from ./ckpt; see the last sketch below)
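To reproduce the loss curves offline, a minimal sketch follows. It assumes only what SavingRecords() above writes: a pickled array whose rows are [step, d_loss, g_loss].

import pickle
import matplotlib.pyplot as plt

# load the loss records written by SavingRecords(): rows of [step, d_loss, g_loss]
with open('./trainLog/loss_variation.loss', 'rb') as f:
    losses = pickle.load(f)

plt.plot(losses[:, 0], losses[:, 1], color='red', lw=1, label='discriminator')
plt.plot(losses[:, 0], losses[:, 2], color='blue', lw=1, label='generator')
plt.xlabel('iterations')
plt.ylabel('loss')
plt.legend()
plt.show()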
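For the generation log, a sketch under the same assumptions: GenLog.log holds one batch of generator outputs (tanh range [-1, 1]) recorded every 5 steps, and the CV2_BATCH_SHOW call pattern from the local visualization module is reused exactly as in the training loop.

import pickle
import numpy as np
import visualization

# load the generation log: one batch of samples recorded every 5 training steps
with open('./trainLog/GenLog.log', 'rb') as f:
    gen_log = pickle.load(f)

# replay the first 9 samples of every recorded batch as a 3x3 grid
for batch in gen_log:
    grid = (np.reshape(batch[0:9], [-1, 28, 28, 1]) + 1) / 2  # map tanh range to [0, 1]
    visualization.CV2_BATCH_SHOW(grid, 1, 3, 3, delay=100)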
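Finally, to verify the generator: since the saver above stored only g_vars, a fresh graph containing just the generator can be rebuilt and the newest checkpoint restored. A minimal sketch, assuming Generator_MLP from the script is in scope:

import tensorflow as tf
import numpy as np

tf.reset_default_graph()
# rebuild the generator graph with the same variable scope the saver used
latents = tf.placeholder(shape=[None, 128], dtype=tf.float32, name='latents')
_, g_outputs = Generator_MLP(latents, 28 * 28, reuse=False)

saver = tf.train.Saver()
with tf.Session() as sess:
    # restore the newest generator checkpoint written during training
    saver.restore(sess, tf.train.latest_checkpoint('./ckpt'))
    z = np.random.normal(0, 1, size=[25, 128]).astype(np.float32)
    samples = sess.run(g_outputs, feed_dict={latents: z})
    images = (np.reshape(samples, [-1, 28, 28, 1]) + 1) / 2  # map to [0, 1] for display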