目录
本文是针对MNIST手写数据的 ac_gan_tensorflow.py的代码解读,全文按py代码顺序依次解读,对于理解acgan的基本原理有很大的帮助,可以直接运行,但实际操作应该配合其他的网络架构或者improve technique。
代码解读
line1-6 import需要的库
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np # 定义矩阵
import matplotlib.pyplot as plt # 画图
import matplotlib.gridspec as gridspec # 定义网格
import os
line 9-18 load MNIST数据 & 定义关键参数
mnist = input_data.read_data_sets('../MNIST', one_hot=True)
mb_size = 32
X_dim = mnist.train.images.shape[1] # 图像维度 [数据量,图像维度]
y_dim = mnist.train.labels.shape[1] # Label 数据长
z_dim = 10 # 噪音维度
h_dim = 128 # 中间层神经元数
eps = 1e-8 # 定义一个很小的数 +eps 可保证不为0
lr = 1e-3 # 学习率
d_steps = 3 # 没用到
其中,读取MNIST数据时,需要将二进制的MNIST数据download完成放在"MNIST"文件夹中
line 21- 34 定义plot 函数 (可作为代码块之后自己用)
def plot(samples):
fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0.05, hspace=0.05)
for i, sample in enumerate(samples):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(28, 28), cmap='Greys_r') # cmp:
return fig
line 37-40 定义xavier 初始化函数
def xavier_init(size):
in_dim = size[0]
xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
return tf.random_norma