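Reference code: https://github.com/carpedm20/DCGAN-tensorflow
I. Handwritten digit test (MNIST)
1) Download the dataset:
2) Put the files under ./data (in ./data/mnist, because the code joins ./data with the dataset name) and modify load_mnist in model.py: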
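def load_mnist(self):
    # Assumes os and numpy (as np) are already imported at the top of model.py.
    # "./data" is portable; the original ".\data" only resolves on Windows.
    data_dir = os.path.join("./data", self.dataset_name)

    # Files are opened in binary mode. Image files carry a 16-byte header and
    # label files an 8-byte header, which is why the slices start at 16 and 8.
    # The builtin float replaces np.float, which NumPy 1.24 removed.
    with open(os.path.join(data_dir, 'train-images.idx3-ubyte'), 'rb') as fd:
        loaded = np.fromfile(file=fd, dtype=np.uint8)
    trX = loaded[16:].reshape((60000, 28, 28, 1)).astype(float)

    with open(os.path.join(data_dir, 'train-labels.idx1-ubyte'), 'rb') as fd:
        loaded = np.fromfile(file=fd, dtype=np.uint8)
    trY = loaded[8:].reshape((60000)).astype(float)

    with open(os.path.join(data_dir, 't10k-images.idx3-ubyte'), 'rb') as fd:
        loaded = np.fromfile(file=fd, dtype=np.uint8)
    teX = loaded[16:].reshape((10000, 28, 28, 1)).astype(float)

    with open(os.path.join(data_dir, 't10k-labels.idx1-ubyte'), 'rb') as fd:
        loaded = np.fromfile(file=fd, dtype=np.uint8)
    teY = loaded[8:].reshape((10000)).astype(float)
    # (remainder of load_mnist not shown)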
Then start training:
python main.py --dataset mnist --input_height=28 --output_height=28 --train
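The offsets 16 and 8 in load_mnist come from the IDX container format used by the MNIST files: a 4-byte magic number (two zero bytes, a data-type code, the number of dimensions) followed by one big-endian 32-bit size per dimension. The sketch below reads the shape from the header instead of hard-coding it. read_idx is a hypothetical helper that is not part of the DCGAN-tensorflow repo, and it assumes the file stores unsigned bytes (type code 0x08), which is the case for MNIST.

```python
import struct

import numpy as np


def read_idx(path):
    """Read an unsigned-byte IDX file (the MNIST format) into a NumPy array."""
    with open(path, "rb") as f:
        # Magic number: two zero bytes, a data-type code, the dimension count.
        zero, dtype_code, ndim = struct.unpack(">HBB", f.read(4))
        if zero != 0 or dtype_code != 0x08:
            raise ValueError("not an unsigned-byte IDX file: %s" % path)
        # One big-endian 32-bit size per dimension follows the magic number.
        shape = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
        data = np.frombuffer(f.read(), dtype=np.uint8)
    # reshape raises ValueError if the payload does not match the header.
    return data.reshape(shape)
```

With this helper, the training images could be loaded as read_idx(os.path.join(data_dir, 'train-images.idx3-ubyte'))[..., np.newaxis].astype(float), which gives the same (60000, 28, 28, 1) array as the slicing approach while also catching truncated or mismatched files.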
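3) Results:
4) Generated images:
II. Training on your own data
1) Prepare the images and put them under ./data, in a subfolder named after the dataset (for example ./data/anime for --dataset anime)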
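Before training, it can help to confirm that the images are where the script will look for them. As far as I can tell from model.py, the repo collects training files with a glob over ./data/<dataset>/<input_fname_pattern>. The sketch below mirrors that lookup; list_training_images is a hypothetical helper, and the directory layout is my reading of the repo rather than a documented guarantee.

```python
import os
from glob import glob


def list_training_images(dataset, pattern="*.jpg", data_dir="./data"):
    """Return the sorted image paths matching data_dir/dataset/pattern."""
    search = os.path.join(data_dir, dataset, pattern)
    paths = sorted(glob(search))
    if not paths:
        # An empty match usually means a wrong folder name or file extension.
        raise FileNotFoundError("no files match %s" % search)
    return paths
```

For example, len(list_training_images("anime")) reports how many .jpg files a run with --dataset anime --input_fname_pattern "*.jpg" would pick up.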
2) Start training:
python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset anime --crop --train --epoch 2 --input_fname_pattern "*.jpg"
Training ran for only 2 epochs (--epoch 2).
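To make the size flags in the command concrete: with --crop, my understanding of the repo's utils.py is that each image is center-cropped to input_height x input_width (96 x 96 here), resized to output_height x output_width (48 x 48), and scaled to [-1, 1]. The sketch below reproduces that arithmetic with NumPy only. center_crop_resize is a hypothetical name, and the nearest-neighbour resize is a stand-in for the image-library resize the repo uses, so pixel values will differ slightly from the real pipeline. It assumes the image is at least as large as the crop window.

```python
import numpy as np


def center_crop_resize(image, crop_h, crop_w, out_h, out_w):
    """Center-crop an HxWxC uint8 image, shrink it, and scale to [-1, 1]."""
    h, w = image.shape[:2]
    # Top-left corner of the centered crop window.
    top = int(round((h - crop_h) / 2.0))
    left = int(round((w - crop_w) / 2.0))
    cropped = image[top:top + crop_h, left:left + crop_w]
    # Nearest-neighbour resize: choose a source row/column per output pixel.
    rows = np.arange(out_h) * crop_h // out_h
    cols = np.arange(out_w) * crop_w // out_w
    resized = cropped[rows][:, cols]
    # Map pixel values from [0, 255] to [-1, 1].
    return resized.astype(np.float32) / 127.5 - 1.0
```

Calling center_crop_resize(img, 96, 96, 48, 48) on a 96 x 96 RGB image returns a (48, 48, 3) float32 array, which matches the generator output size requested by --output_height 48 --output_width 48.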