生成对抗网络(GAN)及其变种已经成为最近十年以来机器学习领域最为重要的思想。--2018图灵奖得主 Yann LeCun
GAN的基础知识复习:click here
GAN模型的挑战即训练优化:click here
1、模型简介及代码
本程序主要是采用最初GAN的基本原理,选择简单的二层神经网络以及MNIST数据集并基于tensorflow平台来实现,以求得对GAN的原理以及实现过程有一个更深入的理解。
程序框架及注释:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
sess = tf.InteractiveSession()
z_dim = 100
batchs= 128
mnist = input_data.read_data_sets("/home/zhaocq/桌面/tensorflow/mnist/raw/",one_hot=True)
def weight_variable(shape,name):
initial = tf.random_normal(shape, stddev=0.01)
return tf.Variable(initial,name = name)
def bias_variable(shape,name):
initial = tf.zeros(shape) #给偏置增加小的正值用来避免死亡节点比如
#inital = tf.constant(0.1, tf.float32, shape)
return tf.Variable(initial,name = name)
#生成器随机噪声100维
z = tf.placeholder(tf.float32,shape=[None,100],name = 'z')
#鉴别器准备MNIST图像输入设置
x = tf.placeholder(tf.float32,shape=[None,784],name = 'x')
#生成器参数定义
g_w1 = weight_variable([100,128],'g_w1')
g_b1 = bias_variable([128],'g_b1')