导入依赖包并加载 MNIST 数据：
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pickle
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST through the TF1 tutorial helper, using ./MNIST_data as the data
# directory; one_hot=True returns the labels as one-hot vectors.
mnist = input_data.read_data_sets(train_dir='./MNIST_data', one_hot=True)
GAN 的输入：真实图像（real）与噪声向量（noise）的占位符：
def get_input(real_size, noise_size):
    """Create the two float32 input placeholders used by the GAN.

    Args:
        real_size: feature width of the real-image placeholder (presumably
            the flattened pixel count, e.g. 784 for MNIST -- confirm at the
            call site).
        noise_size: feature width of the noise placeholder fed to the
            generator.

    Returns:
        A ``(real_img, noise_img)`` pair of ``tf.float32`` placeholders with
        shapes ``[None, real_size]`` and ``[None, noise_size]``. The leading
        ``None`` leaves the batch dimension open.
    """
    # Placeholders are created in order: real first, then noise.
    real_img, noise_img = (
        tf.placeholder(dtype=tf.float32, shape=[None, width])
        for width in (real_size, noise_size)
    )
    return real_img, noise_img
Generator
def get_generator(noise_img, u_units, out_dim, reuse=False, alpha=0.01,
                  drop_rate=0.8, training=False):
    """Build the generator: dense -> leaky ReLU -> dropout -> dense -> tanh.

    All variables live under the ``'generator'`` variable scope. Call once
    with ``reuse=False`` to create them, and again with ``reuse=True`` to
    build another graph branch that shares the same weights.

    Args:
        noise_img: input noise tensor, presumably ``[batch, noise_size]``
            (e.g. the noise placeholder) -- confirm at the call site.
        u_units: number of units in the hidden dense layer.
        out_dim: width of the output dense layer, i.e. the size of each
            generated sample (for images, the number of pixels).
        reuse: forwarded to ``tf.variable_scope`` to share existing variables.
        alpha: slope used for negative inputs in the leaky ReLU.
        drop_rate: fraction of hidden activations to DROP. This follows
            ``tf.layers.dropout`` ``rate`` semantics, not a keep probability.
        training: Python bool or boolean scalar tensor. ``tf.layers.dropout``
            only applies dropout when this is true; with the default
            (False), the dropout layer is a no-op.

    Returns:
        ``(logits, outputs)``: the output layer's pre-activation values and
        their ``tanh``, which lies in (-1, 1).
    """
    with tf.variable_scope('generator', reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, u_units)
        # Leaky ReLU: max(alpha * x, x) keeps a small slope for x < 0.
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        # tf.layers.dropout defaults to training=False (identity), so the
        # flag must be passed explicitly for dropout to take effect.
        # NOTE(review): 0.8 drops 80% of units; if 0.8 was meant as a keep
        # probability, callers should pass drop_rate=0.2.
        hidden1 = tf.layers.dropout(hidden1, rate=drop_rate, training=training)
        # Output layer: out_dim units (one per pixel of a generated image).
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)
        return logits, outputs
Discriminator
def get_discriminator(img,n_units,out_dim=1,reuse=False,alpha=0.01):
with tf.variable_scope('discriminator',reuse=reuse):
hidden1 = tf.layers.dense(img,n_units)
hi