打印损失:
迭代变化:可看到图像逐渐变得清晰。
import tensorflow as tf
import numpy as np
import pickle
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST dataset through the TensorFlow tutorial helper, using
# './data/' as the data directory.
# NOTE(review): the helper presumably downloads the files when they are
# missing -- confirm before running offline.
mnist = input_data.read_data_sets('./data/')
# Take the training image at index 50 as a single sample.
# NOTE(review): its use is not visible in this part of the file; presumably it
# is displayed with matplotlib for inspection -- confirm.
img = mnist.train.images[50]
def get_inputs(real_size, noise_size):
    """Create the graph input placeholders for real images and noise.

    Args:
        real_size: Size of the second dimension of the real-image placeholder.
        noise_size: Size of the second dimension of the noise placeholder.

    Returns:
        A ``(real_img, noise_img)`` tuple of ``tf.float32`` placeholders with
        shapes ``[None, real_size]`` and ``[None, noise_size]``.
    """
    # The first dimension is None, so it is not fixed when the graph is built.
    real_img = tf.placeholder(tf.float32, [None, real_size], name='real_img')
    noise_img = tf.placeholder(tf.float32, [None, noise_size], name='noise_img')
    return real_img, noise_img
def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
"""
生成器
noise_img: 生成器的输入
n_units: 隐层单元个数
out_dim: 生成器输出tensor的size,这里应该为32*32=784
alpha: leaky ReLU系数
"""
with tf.variable_scope("generator", reuse=reuse):
# hidden layer
hidden1 = tf.layers.dense(noise_img, n_units)
# leaky ReLU
hidden1 = tf.maximum(alpha * hidden1, hidden1)
# dropout
hidden1 &#