# 与基本GAN思想类似，只不过生成器和鉴别器中主要使用卷积层（生成器使用转置卷积层）来代替全连接层。
# 基本GAN参见：Tensorflow2 GAN 系列（一）——基本GAN
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
# Load MNIST. Only the training images are used by the visible training code;
# the labels and the test split are not read.
(train_images, train_labels), _ = keras.datasets.mnist.load_data()

# Rescale uint8 pixels from [0, 255] to [-1, 1] so real images share the
# output range of the generator's final tanh layer.
train_images = 2 * tf.cast(train_images, tf.float32) / 255. - 1

# Append a channel axis: (N, 28, 28) -> (N, 28, 28, 1), the input shape the
# discriminator is built with.
train_images = train_images[..., tf.newaxis]

Batch_Size = 256
Buffer_Size = 60000  # shuffle buffer size (as large as the MNIST training set)

# Shuffle, then group into batches. The last batch may be smaller than
# Batch_Size because drop_remainder is not set.
dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(Buffer_Size)
    .batch(Batch_Size)
)
def generator_model(input_dim=100):
    """Build the convolutional generator mapping a noise vector to a 28x28x1 image.

    Layout: Dense projection reshaped to 7x7x256, then three Conv2DTranspose
    stages (7x7x128 -> 14x14x64 -> 28x28x1). Every hidden stage is followed by
    BatchNormalization and LeakyReLU. The output layer uses tanh, so pixel
    values lie in [-1, 1], matching the [-1, 1] scaling of the real images.

    Args:
        input_dim: Length of the latent noise vector fed to the model.
            Defaults to 100, the latent size used when sampling noise in this
            script, so existing ``generator_model()`` calls are unchanged.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential()

    # Project the noise and reshape it into a 7x7 map with 256 channels.
    # Biases are disabled where BatchNormalization follows, because its
    # learned offset (beta) makes a separate bias redundant.
    model.add(layers.Dense(7 * 7 * 256, input_shape=(input_dim,), use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape([7, 7, 256]))

    # Stride 1 keeps the spatial size: 7x7x128.
    model.add(layers.Conv2DTranspose(filters=128, kernel_size=(5, 5), strides=(1, 1),
                                     padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Stride 2 upsamples: 14x14x64.
    model.add(layers.Conv2DTranspose(filters=64, kernel_size=(5, 5), strides=(2, 2),
                                     padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Stride 2 upsamples to the final 28x28x1 image; tanh bounds it to [-1, 1].
    model.add(layers.Conv2DTranspose(filters=1, kernel_size=(5, 5), strides=(2, 2),
                                     padding='same', use_bias=False, activation="tanh"))
    return model
def discriminator_model():
    """Build the convolutional discriminator for 28x28x1 images.

    Three stride-2 Conv2D stages (64, 128, 256 filters; 5x5 kernels; 'same'
    padding), each followed by LeakyReLU and Dropout(0.3), then Flatten and a
    single-unit Dense layer. The output is a raw logit with no sigmoid, which is
    why the script's BinaryCrossentropy is created with ``from_logits=True``.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    return tf.keras.Sequential([
        # 28x28x1 -> 14x14x64
        layers.Conv2D(64, kernel_size=(5, 5), strides=(2, 2), padding='same',
                      input_shape=(28, 28, 1)),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        # 14x14x64 -> 7x7x128
        layers.Conv2D(128, kernel_size=(5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        # 7x7x128 -> 4x4x256
        layers.Conv2D(256, kernel_size=(5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        # Collapse to one real/fake logit per image.
        layers.Flatten(),
        layers.Dense(1),
    ])
# Shared binary cross-entropy. from_logits=True because the discriminator ends
# in a plain Dense(1) layer with no sigmoid, so its outputs are logits.
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real_out, fake_out):
    """Return the discriminator loss for one batch.

    The discriminator is trained to label real images 1 and generated images 0.
    The result is the sum of the two binary cross-entropy terms.

    Args:
        real_out: Discriminator logits for real images.
        fake_out: Discriminator logits for generated images.
    """
    real_targets, fake_targets = tf.ones_like(real_out), tf.zeros_like(fake_out)
    return cross_entropy(real_targets, real_out) + cross_entropy(fake_targets, fake_out)
def generator_loss(fake_out):
    """Return the generator loss for one batch.

    The generator is rewarded when the discriminator labels its images as real,
    i.e. it minimises cross-entropy against an all-ones target.

    Args:
        fake_out: Discriminator logits for generated images.
    """
    return cross_entropy(tf.ones_like(fake_out), fake_out)
# --- Training hyper-parameters ---------------------------------------------
Epochs = 100               # NOTE(review): the visible __main__ passes 200 to train(), not this
input_dim = 100            # length of each latent noise vector
num_exp_to_generate = 16   # number of samples plotted after every epoch

# Independent Adam optimizers, same learning rate for both networks.
generator_opt = tf.keras.optimizers.Adam(1e-5)
discriminator_opt = tf.keras.optimizers.Adam(1e-5)

# Fixed noise batch, reused after every epoch so progress is shown on the
# same latent inputs.
seed = tf.random.normal([num_exp_to_generate, input_dim])

generator = generator_model()
discriminator = discriminator_model()
@tf.function  # compile the step into a graph; eager execution is much slower
def train_step(images):
    """Run one simultaneous generator/discriminator update on a batch of real images.

    Both networks are called with ``training=True``. In a custom training loop
    Keras otherwise runs layers in inference mode. The generator's
    BatchNormalization would then use (never-updated) moving statistics, and
    the discriminator's Dropout would be disabled.

    Args:
        images: Batch of real images scaled to [-1, 1]; presumably shaped
            (batch, 28, 28, 1) as produced by ``dataset`` — confirm for other
            callers.
    """
    # Size the fake batch from the real one. The dataset is batched without
    # drop_remainder, so the final batch can be smaller than Batch_Size.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, input_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        real_out = discriminator(images, training=True)
        gen_img = generator(noise, training=True)
        fake_out = discriminator(gen_img, training=True)

        dis_loss = discriminator_loss(real_out, fake_out)
        gen_loss = generator_loss(fake_out)

    gen_grad = gen_tape.gradient(gen_loss, generator.trainable_variables)
    dis_grad = dis_tape.gradient(dis_loss, discriminator.trainable_variables)

    discriminator_opt.apply_gradients(zip(dis_grad, discriminator.trainable_variables))
    generator_opt.apply_gradients(zip(gen_grad, generator.trainable_variables))
def genrate_plot_image(gen_model, test_noise):
    """Generate images from ``test_noise`` with ``gen_model`` and show them in a grid.

    The grid is as close to square as possible (4x4 for 16 samples, as before),
    so any number of noise vectors can be plotted. Generator outputs are
    assumed to be tanh values in [-1, 1]. They are mapped to [0, 1] and drawn
    in grayscale with a fixed display range, so images are not individually
    contrast-stretched.

    Args:
        gen_model: Generator model. It is called with ``training=False`` so
            BatchNormalization uses its moving statistics.
        test_noise: Batch of latent vectors, one generated image per row.
    """
    import math

    pre_images = gen_model(test_noise, training=False)
    num_images = pre_images.shape[0]
    if not num_images:
        return  # nothing to draw; also avoids a zero-sized grid

    cols = math.ceil(math.sqrt(num_images))
    rows = math.ceil(num_images / cols)
    plt.figure(figsize=(cols, rows))
    for i in range(num_images):
        plt.subplot(rows, cols, i + 1)
        # Map [-1, 1] -> [0, 1] and pin the colour range so brightness is
        # comparable across images and epochs.
        plt.imshow((pre_images[i, :, :, 0] + 1) / 2, cmap='gray', vmin=0, vmax=1)
        plt.axis('off')
    plt.show()
def train(dataset, epochs):
    """Train the module-level GAN for ``epochs`` full passes over ``dataset``.

    After every pass the epoch index is printed, and images generated from
    the fixed ``seed`` noise are plotted.
    """
    for epoch_index in range(epochs):
        for batch in dataset:
            train_step(batch)

        # Progress report plus a visual check on the fixed noise batch.
        print(epoch_index)
        genrate_plot_image(generator, seed)
if __name__ == '__main__':
    # NOTE(review): runs 200 epochs, while the module-level Epochs constant
    # is 100 and unused here — confirm which value is intended.
    train(dataset, epochs=200)