This post again uses handwritten digit generation to introduce CGAN.
You should first be familiar with the basic idea of GANs, covered in:
Tensorflow2 GAN Series (1): Basic GAN
CGAN is short for Conditional GAN; as the name suggests, it is a conditional generative adversarial network.
It differs from the basic GAN in the following ways:
Generator input: noise + condition.
The noise is just a vector of random numbers; for handwritten digit generation, the condition is the digit label 0-9.
Discriminator input: generated or real image + condition.
The discriminator therefore not only has to judge whether an image is real, but also whether it is consistent with the given label.
Once the generator is trained, we can generate the specified content simply by feeding it noise together with a condition (label).
The full code is as follows:
# Conditional GAN: both the generator and the discriminator take a condition input and judge based on it
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np

# Load MNIST and scale the images to [-1, 1] to match the generator's tanh output
(images, labels), _ = keras.datasets.mnist.load_data()
images = 2 * tf.cast(images, tf.float32) / 255. - 1
images = tf.expand_dims(images, -1)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).shuffle(images.shape[0]).batch(256)

noise_dim = 50
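Before building the networks, it can help to peek at one batch and confirm the shapes and value range the models will see (this check is illustrative and not part of the original script):

# Illustrative check: images (256, 28, 28, 1) scaled to [-1, 1], labels (256,)
for img_batch, lbl_batch in dataset.take(1):
    print(img_batch.shape, lbl_batch.shape)
    print(float(tf.reduce_min(img_batch)), float(tf.reduce_max(img_batch)))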
def generator_model():
    seed = layers.Input(shape=(noise_dim,))   # noise vector
    label = layers.Input(shape=())            # one scalar digit label per sample
    x = layers.Embedding(10, 50, input_length=1)(label)  # 10 classes, each mapped to a 50-dim vector
    x = layers.concatenate([seed, x])         # 100-dim vector: noise + label embedding
    x = layers.Dense(3 * 3 * 128, activation='relu', use_bias=False)(x)
    x = layers.Reshape([3, 3, 128])(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), use_bias=False)(x)  # 7*7*64
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    # 'same' pads so the output size is exactly input size * stride; 'valid' uses no padding
    x = layers.Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)  # 14*14*32
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)
    # The output layer normally does not use batch normalization
    x = layers.Activation(activation='tanh')(x)
    model = tf.keras.Model(inputs=[seed, label], outputs=x)
    return model
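A quick sanity check (not part of the original script) is to build a throwaway generator and confirm that a batch of noise vectors plus labels really yields 28x28x1 images:

# Illustrative shape check: expect (4, 28, 28, 1)
g_test = generator_model()
test_noise = tf.random.normal([4, noise_dim])
test_labels = tf.constant([0, 1, 2, 3])
print(g_test((test_noise, test_labels), training=False).shape)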
def discriminator_model():
    image = layers.Input(shape=(28, 28, 1))
    label = layers.Input(shape=())             # one scalar digit label per sample
    x = layers.Embedding(10, 28 * 28, input_length=1)(label)  # 10 classes, each mapped to a 784-dim vector
    x = layers.Reshape([28, 28, 1])(x)         # turn the label embedding into a 28*28*1 condition map
    x = layers.concatenate([x, image])         # 28*28*2: condition map stacked with the image
    x = layers.Conv2D(32, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)   # 14*14*32
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)   # 7*7*64
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)  # 4*4*128
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Flatten()(x)
    out = layers.Dense(1)(x)                   # a single real/fake logit
    model = tf.keras.Model(inputs=[image, label], outputs=out)
    return model
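The same kind of throwaway check (again purely illustrative) confirms that the discriminator returns one logit per image:

# Illustrative shape check: expect (4, 1)
d_test = discriminator_model()
test_imgs = tf.random.normal([4, 28, 28, 1])
test_labels = tf.constant([0, 1, 2, 3])
print(d_test((test_imgs, test_labels), training=False).shape)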
disc = discriminator_model()
gen = generator_model()

# from_logits=True because the discriminator's final Dense layer has no sigmoid
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def disc_loss(real_out, fake_out):
    real_loss = bce(tf.ones_like(real_out), real_out)    # real images should be classified as 1
    fake_loss = bce(tf.zeros_like(fake_out), fake_out)   # generated images should be classified as 0
    return real_loss + fake_loss

def gen_loss(fake_out):
    fake_loss = bce(tf.ones_like(fake_out), fake_out)    # the generator wants its images classified as 1
    return fake_loss

gen_opt = tf.keras.optimizers.Adam(1e-5)
dis_opt = tf.keras.optimizers.Adam(1e-5)
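As a rough numeric illustration (dummy logits of my own, not from the original post): when the discriminator scores real images high and fakes low, disc_loss is small while gen_loss is large, which is exactly the tension the training loop exploits.

# Illustrative only: logits of +5 mean "looks real", -5 mean "looks fake"
real_logits = tf.constant([[5.0], [5.0]])
fake_logits = tf.constant([[-5.0], [-5.0]])
print(disc_loss(real_logits, fake_logits).numpy())  # small: the discriminator is doing well
print(gen_loss(fake_logits).numpy())                # large: the generator is being caught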
def train_step(images, labels):
    noise = tf.random.normal([labels.shape[0], noise_dim])
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake_img = gen((noise, labels), training=True)     # generate fakes conditioned on the real labels
        fakeout = disc((fake_img, labels), training=True)
        realout = disc((images, labels), training=True)
        d_loss = disc_loss(realout, fakeout)
        g_loss = gen_loss(fakeout)
    # One optimizer step for the generator and one for the discriminator on every batch
    g_grad = g_tape.gradient(g_loss, gen.trainable_variables)
    gen_opt.apply_gradients(zip(g_grad, gen.trainable_variables))
    d_grad = d_tape.gradient(d_loss, disc.trainable_variables)
    dis_opt.apply_gradients(zip(d_grad, disc.trainable_variables))
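The step above runs eagerly; if training feels slow, one common option (an addition of mine, not in the original post) is to compile it into a graph with tf.function:

# Optional speed-up: wrap the eager train_step in a tf.function (behaviour is unchanged)
compiled_train_step = tf.function(train_step)
# then call compiled_train_step(images, labels) in the training loop instead of train_step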
def plot_gen_image(model, noise, label, epoch):
    gen_image = model((noise, label), training=False)
    fig = plt.figure(figsize=(10, 1))
    for i in range(10):
        plt.subplot(1, 10, i + 1)
        plt.imshow(tf.squeeze(gen_image[i] + 1) / 2.)   # rescale from [-1, 1] back to [0, 1] for display
        plt.axis('off')
    plt.show()
noise = tf.random.normal([10, noise_dim])
label = tf.constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

def main():
    for epoch in range(200):
        for images, labels in dataset:
            train_step(images, labels)
        print('Epoch:', epoch)
        if (epoch + 1) % 10 == 0:
            plot_gen_image(gen, noise, label, epoch)

if __name__ == '__main__':
    main()
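Once main() has finished, the trained generator alone is enough to produce whatever digit you ask for. A minimal sketch (the digit 7 and the batch of five are arbitrary choices of mine, not from the original post):

# Illustrative: generate five images of the digit 7 with the trained generator
z = tf.random.normal([5, noise_dim])
sevens = tf.constant([7, 7, 7, 7, 7])
samples = gen((z, sevens), training=False)   # shape (5, 28, 28, 1), values in [-1, 1]
samples = (samples + 1) / 2.                 # rescale to [0, 1] for display or saving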