上一篇文章介绍了DCGANS基本原理,那么这篇文章我们就利用MNIST数据集进行DCGAN的实例。
生成对抗网络(GANs),通过两个模型通过对抗过程同时训练。一个生成器G学习创造看起来真实的图像,而判别器(D家”)学习区分真假图像 。训练过程中,生成器在生成逼真图像方面逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。当判别器不再能够区分真实图片和伪造图片时,训练过程达到平衡。
下方动画展示了当训练了 50 个epoch (全部数据集迭代50次) 时生成器所生成的一系列图片。图片从随机噪声开始,随着时间的推移越来越像手写数字。
那么我们开始例子的编写,该例子是用python3,tensorflow2.0编写,可以直接复制使用。并且我在该例子的编写过程中,添加了详细的代码注释,并且解释部分tensorflow2的API使用说明,供大家参考,如有错误望大家指正^_^。
1、导入必要的依赖库
import tensorflow as tf
#python图片操作库
import glob
import imageio
import PIL
#训练基本库
import matplotlib.pyplot as plt
import os
import sys
import numpy as np
from tensorflow.keras import layers
import time
#python交互式shell
from IPython import display
2、载入MNIST数据集并生成模型
#生成器创建
#载入mnist数据集
(train_images,train_label),(_,_) = tf.keras.datasets.mnist.load_data()
3、DCGANS并没有做其他的image的预处理操作,只讲数据进行[-1,1]的压缩以便训练
图片数据归一化的解释可以参考https://blog.csdn.net/wind82465/article/details/108711150
train_images = train_images.reshape(train_images.shape[0],28,28,1).astype('float32')#转化为float
train_images = (train_images - 127.5)/127.5 #没有做其他预处理,只讲图片数据修改到[-1,1]的区间内
4、定义buffer size和batch size
#定义buffer size 和batch szie
BUFFER_SIZE = 60000
BATCH_SIZE = 256
5、打乱数据,减少过拟合
#打乱数据 防止过拟合
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
6、创建生成器G模型
#创建生成器
def make_generator_model():
model = tf.keras.Sequential()#创建模型实例
#第一层须指定维度 #batch无限制
model.add(layers.Dense(7*7*BATCH_SIZE, use