GAN的代码附上
第一部分,数据处理,这里用的是mnist数据集
import tensorflow as tf
import numpy as np
#import datetime
import matplotlib.pyplot as plt
#如果”/tmp/data/”目录下存在mnist数据集,则加载,否则先下载后加载
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)#如果”MNIST_data/”目录下存在mnist数据集,则加载,否则先下载后加载
#one_hot:以Numpy数组的形式中存储着训练集、验证集、测试集
#MNIST 初始为784维向量 reshape可还原
sample_image = mnist.train.next_batch(1)[0]
print(sample_image.shape)
sample_image = sample_image.reshape([28,28])
plt.imshow(sample_image,cmap='Greys')#Matplotlib库中,调用imshow()函数实现热图绘制
plt.colorbar()
plt.show()
#print(sample_image.reshape)
#增加颜色类标的代码是plt.colorbar()
#参数cmap用于设置热图的Colormap
第二部分,网络模型的书写,这里主要包括两个部分,生成器和判别器。
首先是生成器
def generator(z,batch_size,z_dim,reuse_variables=None): #z=placeholder[None,Z_dimensions], batch_size=50,z_dim=10
#tf.AUTO_REUSE
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE ) as scope:
g_w1=tf.get_variable('g_w1',[z_dim,3136],dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.02))
g_b1=tf.get_variable('g_b1',[3136],initializer=tf.truncated_normal_initializer(stddev=0.02))
g1=tf.matmul(z,g_w1)+g_b1
g1=tf.reshape(g1,[-1,56,56,1])
g1=tf.contrib.layers.batch_norm(g1,epsilon=1e-5,scope='bn1')#标准化函数
g1=tf.nn.relu(g1)
#layer two
g_w2=tf.get_variable('g_w2',[3,3,1,z_dim/2],dtype=tf.float32,initializer=tf.truncated_normal_initializer(stddev=0.02))
g_b2=tf.get_variable('g_b2',[z_dim/2],initializer=tf.truncated_normal_initializer(stddev=0.02))
g2=tf.nn.conv2d(g1,g_w2,strides=[1,2,2,1],padding='SAME')
g2