# DCGAN demo on the MNIST dataset
# Import the required packages
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # silence TensorFlow INFO/WARNING log output
# Build the dataset
# Uses the MNIST handwritten-digit dataset
# Load MNIST; the test split is not needed for GAN training.
(train_image, train_label), (_, _) = tf.keras.datasets.mnist.load_data()
train_image = train_image.astype('float32')
train_image = np.expand_dims(train_image, -1)  # (60000, 28, 28) -> (60000, 28, 28, 1)
print("train_image", train_image.shape)
# Scale pixels from [0, 255] to [-1, 1]: the generator ends in tanh, which
# works best when the data is centered around 0.
train_image = (train_image - 127.5) / 127.5
BATCH_SIZE = 256
BUFFER_SIZE = 60000  # shuffle buffer covering the whole training set
datasets = tf.data.Dataset.from_tensor_slices(train_image)
datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# Generator model
def generator_model():
    """Build the DCGAN generator: a 100-dim noise vector -> 28x28x1 tanh image."""
    return tf.keras.Sequential([
        # Project the noise vector to a 7*7*256 feature volume.
        tf.keras.layers.Dense(7 * 7 * 256, input_shape=(100,), use_bias=False),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),  # LeakyReLU is the usual GAN activation
        tf.keras.layers.Reshape((7, 7, 256)),
        # 7x7x256 -> 7x7x128
        tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1),
                                        padding='same', use_bias=False),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),
        # 7x7x128 -> 14x14x64
        tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2),
                                        padding='same', use_bias=False),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(),
        # 14x14x64 -> 28x28x1; the last layer uses tanh and no BatchNorm.
        tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2),
                                        padding='same', use_bias=False,
                                        activation='tanh'),
    ])
# Discriminator model
def discriminator_model():
    """Build the DCGAN discriminator: 28x28x1 image -> single real/fake logit."""
    model = tf.keras.Sequential()
    # No BatchNorm on the discriminator's first layer (standard DCGAN practice).
    model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2),
                                     padding='same', input_shape=(28, 28, 1)))
    model.add(tf.keras.layers.LeakyReLU())
    # Dropout weakens the discriminator slightly, which makes the GAN easier to train.
    model.add(tf.keras.layers.Dropout(0.3))
    model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Dropout(0.3))
    model.add(tf.keras.layers.Conv2D(256, (5, 5), strides=(2, 2), padding='same'))
    model.add(tf.keras.layers.LeakyReLU())
    # GlobalAveragePooling2D already produces a (batch, channels) tensor, so the
    # Flatten layer that followed it was a no-op and has been removed.
    model.add(tf.keras.layers.GlobalAveragePooling2D())
    model.add(tf.keras.layers.Dense(1))  # single logit: real vs. fake
    return model
# Loss functions and optimizers for the generator and the discriminator
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True) # binary cross-entropy on raw logits (no sigmoid in the models)
def discriminator_loss(real_out, fake_out):
    """Discriminator loss: real samples should score 1, generated samples 0."""
    loss_on_real = cross_entropy(tf.ones_like(real_out), real_out)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_out), fake_out)
    return loss_on_real + loss_on_fake
def generator_loss(fake_out):
    """Generator loss: it wants the discriminator to score its fakes as real (1)."""
    return cross_entropy(tf.ones_like(fake_out), fake_out)
# Two models, so two independent Adam optimizers.
generator_opt = tf.keras.optimizers.Adam(1e-4)
discriminator_opt = tf.keras.optimizers.Adam(1e-4)
generator = generator_model()
discriminator = discriminator_model()
# Per-batch training step
# Training configuration
EPOCHS = 10
noise_dim = 100 # length of the random noise vector fed to the generator
num_exp_to_generate = 4 # number of sample images rendered after each epoch
seed = tf.random.normal([num_exp_to_generate,noise_dim]) # 4 fixed noise vectors of length 100, so progress is comparable across epochs
def train_step(images):
    """Run one optimization step for both networks on a single batch of real images."""
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        real_out = discriminator(images, training=True)
        generated_images = generator(noise, training=True)
        fake_out = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_out)
        disc_loss = discriminator_loss(real_out, fake_out)
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    # Apply the gradients to each network's own parameters.
    discriminator_opt.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
    generator_opt.apply_gradients(zip(gen_grads, generator.trainable_variables))
# Train and display results after each epoch
# Display the generated samples
def genrate_plot_image(gen_model, test_noise):
    """Generate images from test_noise and show them on a 2x2 grayscale grid.

    NOTE(review): the function name contains a typo ("genrate"), but it is kept
    because the training loop calls it by this name.
    """
    pre_images = gen_model(test_noise, training=False)
    fig = plt.figure(figsize=(2, 2))
    for idx in range(pre_images.shape[0]):
        plt.subplot(2, 2, idx + 1)
        # Map the tanh output from [-1, 1] back to [0, 1] for display.
        plt.imshow((pre_images[idx, :, :, 0] + 1) / 2, cmap='gray')
        plt.axis('off')
    plt.show()
def train(dataset, epochs):
    """Train the GAN for `epochs` passes over `dataset`.

    Prints a dot per batch as a progress indicator and plots samples from the
    fixed seed after every epoch.
    """
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)
            print('.', end=' ')
        genrate_plot_image(generator, seed)
# Kick off training.
train(datasets, EPOCHS)