生死看淡,不服就GAN (4): Generating MNIST Handwritten Digits with a Fully-Connected GAN

Building the fully-connected GAN network

#*************************************** 生死看淡,不服就GAN **************************************************************
"""
PROJECT:MNIST_GAN_MLP
Author:Ephemeroptera
Date:2018-4-24
QQ:605686962
Reference:' improved_wgan_training-master': <https://github.com/igul222/improved_wgan_training>
           'Zardinality/WGAN-tensorflow':<https://github.com/Zardinality/WGAN-tensorflow>
           'NELSONZHAO/zhihu':<https://github.com/NELSONZHAO/zhihu>
"""

# import dependency
import tensorflow as tf
import numpy as np
import pickle
import visualization
import os
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from threading import Thread
from time import sleep
import time
import cv2

# import MNIST dataset
mnist_dir = r'../mnist_dataset'
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(mnist_dir)

#------------------------------------------------ define model components -----------------------------------------------

# define generator: latent vector z -> 784-dim image, squashed to [-1, 1] by tanh
def Generator_MLP(latents,out_dim,reuse=False):
    units = 128

    with tf.variable_scope("generator", reuse=reuse):
        # dense0: single hidden layer with leaky ReLU
        dense0 = tf.layers.dense(latents,units,activation=tf.nn.leaky_relu,name='dense0')
        # dropout (note: without training=True, tf.layers.dropout is a no-op)
        dropout = tf.layers.dropout(dense0,rate=0.2,name='dropout')
        # dense1: project to pixel space
        logits = tf.layers.dense(dropout, out_dim,name='dense1')
        # output in [-1, 1], matching the rescaled MNIST inputs
        outputs = tf.tanh(logits,name='outputs')

        return logits,outputs

# define discriminator: 784-dim image -> scalar realness logit/probability
def Discriminator_MLP(inputs,out_dim,reuse=False):
    units = 128

    with tf.variable_scope("discriminator", reuse=reuse):
        # dense0: single hidden layer with leaky ReLU
        dense0 = tf.layers.dense(inputs, units, activation=tf.nn.leaky_relu, name='dense0',
                                 kernel_initializer=tf.random_normal_initializer(0,0.1))
        # dense1: project to a single logit
        logits = tf.layers.dense(dense0, out_dim, name='dense1',
                                 kernel_initializer=tf.random_normal_initializer(0,0.1))
        # output probability
        outputs = tf.sigmoid(logits, name='outputs')

        return logits, outputs

# count the total number of trainable parameters
def COUNT_VARS(vars):
    total_para = 0
    for variable in vars:
        # get each shape of vars
        shape = variable.get_shape()
        variable_para = 1
        for dim in shape:
            variable_para *= dim.value
        total_para += variable_para
    return total_para

# display parameter information
def ShowParasList(paras):
    p = open('./trainLog/Paras.txt', 'w')
    p.writelines(['vars_total: %d'%COUNT_VARS(paras),'\n'])
    for variable in paras:
        p.writelines([variable.name, str(variable.get_shape()),'\n'])
        print(variable.name, variable.get_shape())
    p.close()

# build related dirs
def GEN_DIR():
    if not os.path.isdir('ckpt'):
        print('DIR:ckpt NOT FOUND,BUILDING ON CURRENT PATH..')
        os.mkdir('ckpt')
    if not os.path.isdir('trainLog'):
        print('DIR:trainLog NOT FOUND,BUILDING ON CURRENT PATH..')
        os.mkdir('trainLog')

#---------------------------------------------- build graph -------------------------------------------------------------
# hyper-parameters
latents_dim = 128
img_dim = 28*28
smooth = 0.1 # label smoothing factor for the "real" (ones) targets
learn_rate = 0.001

# define input
latents = tf.placeholder(shape=[None,latents_dim],dtype=tf.float32,name='latents')
input_real = tf.placeholder(shape=[None,img_dim],dtype=tf.float32,name='input_real')

# get output of G,D
_, g_outputs = Generator_MLP(latents,img_dim,reuse=False)
d_logits_real, d_outputs_real = Discriminator_MLP(input_real,1,reuse=False)
d_logits_fake, d_outputs_fake = Discriminator_MLP(g_outputs,1,reuse=True)

# define loss (standard non-saturating GAN loss; note the smoothing factor belongs
# inside `labels=...`, so the hard target 1.0 becomes 0.9)
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                                     labels=tf.ones_like(d_logits_real) * (1 - smooth)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                     labels=tf.zeros_like(d_logits_fake)))
d_loss = tf.add(d_loss_real, d_loss_fake)
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                labels=tf.ones_like(d_logits_fake) * (1 - smooth)))
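# NOTE: with the non-saturating loss above, the discriminator is pushed towards
# D(x) -> 1 on real batches and D(G(z)) -> 0 on fakes, while the generator is
# trained to make D(G(z)) -> 1. The `smooth` factor replaces the hard target 1.0
# with 0.9, which keeps the discriminator from becoming over-confident early on.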
# collect trainable variables and set up the Adam optimizers
train_vars = tf.trainable_variables()
GEN_DIR() # build output dirs first: ShowParasList() writes into ./trainLog
ShowParasList(train_vars) # display and save the parameter list
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
d_train_opt = tf.train.AdamOptimizer(learn_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learn_rate).minimize(g_loss, var_list=g_vars)

#------------------------------------------------ iterations -----------------------------------------------------------

max_iters = 20000
batch_size = 64
critic_n = 5 # discriminator updates per generator update
GenLog = [] # snapshots of generated samples during training
Losses = [] # (step, d_loss, g_loss) records
saver = tf.train.Saver(var_list=g_vars) # only the generator is checkpointed

# recording training info
def SavingRecords():
    global Losses
    global GenLog
    # saving Losses
    with open('./trainLog/loss_variation.loss', 'wb') as l:
        losses = np.array(Losses)
        pickle.dump(losses, l)
        print('saving Losses successfully!')
    # saving generated samples
    with open('./trainLog/GenLog.log', 'wb') as g:
        GenLog = np.array(GenLog)
        pickle.dump(GenLog, g)
        print('saving GenLog successfully!')

# define training (a Python generator: it yields per-step losses so FuncAnimation can consume them)
def training():
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        time_start = time.time()  # go
        for steps in range(max_iters+1):

            # fetch a batch of real images (labels are ignored)
            data_batch = mnist.train.next_batch(batch_size)[0]
            # ops.SHOW('real',data_batch[0].reshape([28,28,1]))
            data_batch = data_batch * 2 - 1  # rescale [0, 1] -> [-1, 1] to match the tanh output
            data_batch = data_batch.astype(np.float32)
            z = np.random.normal(0, 1, size=[batch_size, latents_dim]).astype(np.float32)

            # train the discriminator critic_n times per generator step
            for n in range(critic_n):
                sess.run(d_train_opt, feed_dict={input_real: data_batch, latents: z})
            # train the generator once
            sess.run(g_train_opt, feed_dict={latents: z})

            # recording training_losses
            train_loss_d = sess.run(d_loss, feed_dict={input_real: data_batch, latents: z})
            train_loss_g = sess.run(g_loss, feed_dict={latents: z})
            info = [steps, train_loss_d, train_loss_g]

            # recording training products
            gen_samples = sess.run(g_outputs, feed_dict={latents: z})
            visualization.CV2_BATCH_SHOW((np.reshape(gen_samples[0:9], [-1, 28, 28, 1]) + 1) / 2, 1, 3, 3, delay=1)
            print('iters::%d/%d..Discriminator_loss:%.3f..Generator_loss:%.3f..' % (steps, max_iters, train_loss_d, train_loss_g))

            if steps % 5 == 0:
                Losses.append(info)
                GenLog.append(gen_samples)

            if steps % 1000 == 0 and steps > 0:
                saver.save(sess, './ckpt/generator.ckpt', global_step=steps)

            if steps == max_iters:
                # cv2.destroyAllWindows()
                # spawn a thread to save the training records
                sleep(3)
                thread1 = Thread(target=SavingRecords,args=())
                thread1.start()

            yield info

#------------------------------------------------- ANIMATION ----------------------------------------------------------
# ANIMATION
"""
note: this animation plots the runtime variation of the G and D losses as training runs
"""
iters = []
dloss = []
gloss = []
fig = plt.figure()
ax1 = fig.add_subplot(2,1,1,xlim=(0, max_iters), ylim=(-1, 1))
ax2 = fig.add_subplot(2,1,2,xlim=(0, max_iters), ylim=(-20, 20))
ax1.set_title('discriminator_loss')
ax2.set_title('generator_loss')
line1, = ax1.plot([], [], color='red',lw=1,label='discriminator')
line2, = ax2.plot([], [],color='blue', lw=1,label='generator')
fig.tight_layout()

def init():
    line1.set_data([], [])
    line2.set_data([], [])
    return line1,line2

def update(info):
    iters.append(info[0])
    dloss.append(info[1])
    gloss.append(info[2])
    line1.set_data(iters, dloss)
    line2.set_data(iters, gloss)
    return line1, line2

ani = FuncAnimation(fig, update, frames=training,init_func=init, blit=True,interval=1,repeat=False)
plt.show()
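
The listing imports a small `visualization` helper module that the post does not include. A minimal compatible visualization.py might look like the sketch below; the CV2_BATCH_SHOW signature (a batch of HxWx1 images in [0, 1], a scale factor, a rows x cols grid, and a cv2.waitKey delay in milliseconds) is assumed from the single call site in the training loop.

# visualization.py -- assumed, minimal stand-in for the author's helper module
import cv2
import numpy as np

def CV2_BATCH_SHOW(imgs, scale, rows, cols, delay=1):
    # tile up to rows*cols images (shape [N, H, W, 1], values in [0, 1]) into one grid
    h, w = imgs.shape[1], imgs.shape[2]
    grid = np.zeros([rows * h, cols * w], dtype=np.float32)
    for idx in range(min(rows * cols, len(imgs))):
        r, c = divmod(idx, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = imgs[idx, :, :, 0]
    if scale != 1:
        grid = cv2.resize(grid, None, fx=scale, fy=scale)
    cv2.imshow('generated samples', grid)
    cv2.waitKey(delay)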




Experimental results

1. Loss curves
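
During training the loss records are pickled to ./trainLog/loss_variation.loss as an [N, 3] array of (step, d_loss, g_loss). The original post shows the resulting curves as a figure; a short sketch to re-plot them from the saved file:

import pickle
import matplotlib.pyplot as plt

with open('./trainLog/loss_variation.loss', 'rb') as f:
    losses = pickle.load(f)  # shape [N, 3]: (step, d_loss, g_loss)

plt.plot(losses[:, 0], losses[:, 1], 'r', lw=1, label='discriminator')
plt.plot(losses[:, 0], losses[:, 2], 'b', lw=1, label='generator')
plt.xlabel('iterations')
plt.ylabel('loss')
plt.legend()
plt.show()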

2. Generation log
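
GenLog stores a batch of generated samples every 5 steps and is pickled to ./trainLog/GenLog.log, so each record has shape [batch_size, 784] with values in [-1, 1]. The original post shows the evolution as images; a sketch to replay the first sample of each record with OpenCV:

import pickle
import cv2

with open('./trainLog/GenLog.log', 'rb') as f:
    gen_log = pickle.load(f)  # shape [N, batch_size, 784], values in [-1, 1]

for batch in gen_log:
    img = (batch[0].reshape(28, 28) + 1) / 2  # first sample, back to [0, 1]
    cv2.imshow('GenLog', cv2.resize(img, (140, 140)))
    cv2.waitKey(50)
cv2.destroyAllWindows()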

3. Validating the generator
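
Since only the generator variables are checkpointed (tf.train.Saver(var_list=g_vars)), validation means rebuilding the generator graph under the same "generator" variable scope and restoring the latest checkpoint. A minimal sketch, assuming a fresh process with Generator_MLP defined as in the listing above:

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

latents_dim = 128
img_dim = 28 * 28

# rebuild the generator graph with the same variable names as at training time
latents = tf.placeholder(shape=[None, latents_dim], dtype=tf.float32, name='latents')
_, g_outputs = Generator_MLP(latents, img_dim, reuse=False)

g_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
saver = tf.train.Saver(var_list=g_vars)

with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('./ckpt'))
    z = np.random.normal(0, 1, size=[9, latents_dim]).astype(np.float32)
    samples = sess.run(g_outputs, feed_dict={latents: z})
    # map the tanh output from [-1, 1] back to [0, 1] and show a 3x3 grid
    imgs = (np.reshape(samples, [-1, 28, 28]) + 1) / 2
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        plt.imshow(imgs[i], cmap='gray')
        plt.axis('off')
    plt.show()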
