DCGAN叫做深度卷积生成对抗网络(Deep Convolutional Generative Adversarial Network),它是在GAN的基础上把GAN的生成模型和判别模型都用CNN实现,而不是简单的多层感知机。此外,论文还对CNN进行了改进:去掉了CNN中的全连接层,使用了批量归一化(Batch Normalization),使用了反卷积(转置卷积)操作,以及使用了LeakyReLU激活函数等。参考论文:《Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks》,GitHub代码地址:https://github.com/carpedm20/DCGAN-tensorflow。本文在作者源码的基础上稍作修改,使之适用于MNIST数据集,收敛速度很快,代码如下:
#coding=utf-8
import tensorflow as tf
import pickle
import os
import numpy as np
from scipy.misc import imsave
import matplotlib.gridspec as gridspec
import shutil
import math
# 定义一个mnist数据集的类
class mnistReader():
    """Batch reader for a pickled MNIST dataset.

    The pickle file must contain a ``(train_set, valid_set, test_set)``
    triple, where each set is an ``(images, labels)`` pair. Every image is
    reshaped to ``[28, 28, 1]`` when a batch is decoded, so each image must
    hold 784 values.

    Args:
        mnistPath: Path to the pickle file.
        onehot: If True, labels are returned as 10-dimensional one-hot
            vectors; otherwise as plain ints.
    """

    def __init__(self, mnistPath, onehot=True):
        self.mnistPath = mnistPath
        self.onehot = onehot
        self.batch_index = 0  # number of batches served in the current pass
        print('read:', self.mnistPath)
        # NOTE: pickle can execute arbitrary code while loading -- only point
        # this at a trusted file. The `with` block guarantees the file is
        # closed even if unpickling fails.
        with open(self.mnistPath, 'rb') as fo:
            # encoding='bytes' keeps Python 2 `str` objects in the pickle as
            # bytes instead of trying to decode them as ASCII.
            self.train_set, self.valid_set, self.test_set = pickle.load(
                fo, encoding='bytes')
        # Pair each training image with its label so shuffling keeps them aligned.
        self.data_label_train = list(zip(self.train_set[0], self.train_set[1]))
        np.random.shuffle(self.data_label_train)

    def next_train_batch(self, batch_size=100):
        """Return the next ``(images, labels)`` training batch.

        Only full batches are served. Once every full batch of the current
        pass has been returned, the data is reshuffled and reading restarts
        from the beginning; leftover tail samples are skipped for that pass.
        If ``batch_size`` exceeds the dataset size, every call reshuffles and
        returns all samples.
        """
        num_full_batches = len(self.data_label_train) // batch_size
        if self.batch_index >= num_full_batches:
            # Pass finished: start over on a fresh permutation.
            self.batch_index = 0
            np.random.shuffle(self.data_label_train)
        start = self.batch_index * batch_size
        datum = self.data_label_train[start:start + batch_size]
        self.batch_index += 1
        return self._decode(datum, self.onehot)

    def get_sample_label(self, batch_size=64):
        """Return one-hot labels for the first ``batch_size`` training samples.

        The labels are taken in the original (unshuffled) file order and are
        used as the conditions when generating sample images.
        """
        return [self._one_hot(label) for label in self.train_set[1][0:batch_size]]

    @staticmethod
    def _one_hot(label, num_classes=10):
        """Return a float vector of length ``num_classes`` with 1 at ``int(label)``."""
        hot = np.zeros(num_classes)
        hot[int(label)] = 1
        return hot

    def _decode(self, datum, onehot):
        """Split ``(image, label)`` pairs into an image list and a label list.

        Images are reshaped to ``[28, 28, 1]``. Labels become 10-dimensional
        one-hot vectors when ``onehot`` is True, otherwise plain ints.
        """
        rdata = [np.reshape(d, [28, 28, 1]) for d, _ in datum]
        if onehot:
            rlabel = [self._one_hot(l) for _, l in datum]
        else:
            rlabel = [int(l) for _, l in datum]
        return rdata, rlabel
# 批量归一化类的定义
class batch_norm(object):
def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):
with tf.variable_scope(name):
self.epsilon = epsilon
self.momentum = momentum
self.name = name
def __call__(self, x, train=True):
return tf.contrib.layers.batch_norm(x,\
decay=self.momentum, \
updates_coll