一、前言
1、什么是GAN?
GAN主要包括了两个部分,即生成器generator与判别器discriminator。
生成器主要用来学习真实图像分布从而让自身生成的图像更加真实,以骗过判别器。判别器则需要对接收的图片进行真假判别。
在整个过程中,生成器努力地让生成的图像更加真实,而判别器则努力地去识别出图像的真假,这个过程相当于一个二人博弈,随着时间的推移,生成器和判别器在不断地进行对抗,最终两个网络达到了一个动态均衡:生成器生成的图像接近于真实图像分布,而判别器识别不出真假图像,对于给定图像的预测为真的概率基本接近0.5(相当于随机猜测类别)
二、实战
1、参数定义
class MnistModel:
def __init__(self):
# mnist测试集
self.mnist = input_data.read_data_sets('mnist/', one_hot=True)
# 图片大小
self.img_size = self.mnist.train.images[0].shape[0]
# 每步训练使用图片数量
self.batch_size = 64
# 图片分块数量
self.chunk_size = self.mnist.train.num_examples // self.batch_size
# 训练循环次数
self.epoch_size = 300
# 抽取样本数
self.sample_size = 25
# 生成器判别器隐含层数量
self.units_size = 128
# 学习率
self.learning_rate = 0.001