GAN的网络需要一个判别器(Discriminator)和一个生成器(Generator),判别器需要不停判别生成器产生的图片是否为真实图片,直到最后判别不出来生成器生成的图片到底是真是假,则达到纳什均衡点。
本例中,生成器的任务是将大小为100的一维随机数组,通过Conv2DTranspose变成【batchSize,64,64,3】大小的图片。这些生成的图片要最大程度地和真的一样,不让判别器识别出来。
判别器的任务是接收到生成器的假图片,并判断它是假的;接收到真的图片,判断它是真的。
import tensorflow as tf
from tensorflow.keras import Model,layers
#涉及到Kernel Size 和 Strides的调节问题,可以使用公式:N为输入数据的维度,S为步长,P为PADDING个数,F为卷积核大小
#Conv2DTranspose: (N-1)* S - P + F = 输出数据的维度
#Conv2D: (N - F + P)/S + 1 = 输出数据的维度
class Generator(Model):
    """DCGAN generator: maps a latent batch [b, 100] to images [b, 64, 64, 3].

    Pipeline: Dense projection -> reshape to [b, 3, 3, 512] -> five
    transposed convolutions with 'valid' padding, growing the spatial
    size 3 -> 9 -> 29 -> 59 -> 60 -> 64.  The first four stages are
    followed by batch norm; the last uses tanh, so output pixel values
    lie in [-1, 1].
    """
    def __init__(self):
        super(Generator, self).__init__()
        # Project the latent vector up to 3*3*512 features before reshaping.
        self.fullDense = layers.Dense(3 * 3 * 512, activation=tf.nn.leaky_relu)
        # Output size per stage follows (N - 1) * stride + kernel ('valid' padding).
        self.conv1 = layers.Conv2DTranspose(512, 3, 3, padding='valid', activation=tf.nn.leaky_relu)  # 3 -> 9
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2DTranspose(256, 5, 3, padding='valid', activation=tf.nn.leaky_relu)  # 9 -> 29
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2DTranspose(128, 3, 2, padding='valid', activation=tf.nn.leaky_relu)  # 29 -> 59
        self.bn3 = layers.BatchNormalization()
        self.conv4 = layers.Conv2DTranspose(64, 2, 1, padding='valid', activation=tf.nn.leaky_relu)   # 59 -> 60
        self.bn4 = layers.BatchNormalization()
        # Final stage: (60 - 1) * 1 + 5 = 64; tanh bounds the generated image.
        self.conv5 = layers.Conv2DTranspose(3, 5, 1, padding='valid', activation=tf.nn.tanh)          # 60 -> 64
    def call(self, inputs, training=None, mask=None):
        """Generate a batch of images from latent vectors.

        Args:
            inputs: latent batch, shape [b, 100].
            training: forwarded to the BatchNormalization layers.
            mask: unused; kept to match the Keras Model.call signature.

        Returns:
            Tensor of shape [b, 64, 64, 3] with values in [-1, 1].
        """
        feat = tf.reshape(self.fullDense(inputs), [-1, 3, 3, 512])
        stages = ((self.conv1, self.bn1),
                  (self.conv2, self.bn2),
                  (self.conv3, self.bn3),
                  (self.conv4, self.bn4))
        for deconv, norm in stages:
            feat = norm(deconv(feat), training=training)
        return self.conv5(feat)
class Discriminator(Model):
    """DCGAN discriminator: maps images [b, 64, 64, 3] to a real/fake logit [b, 1].

    Four 'valid'-padded convolutions shrink the spatial size
    64 -> 60 -> 28 -> 26 -> 12, followed by flatten and a one-unit
    Dense layer.  The output is an unnormalized logit: apply sigmoid
    downstream, or use a from_logits loss.
    """
    def __init__(self):
        super(Discriminator, self).__init__()
        # Output size per stage follows floor((N - kernel) / stride) + 1 ('valid' padding).
        self.conv1 = layers.Conv2D(64, 5, 1, padding='valid', activation=tf.nn.leaky_relu)   # 64 -> 60
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(128, 5, 2, padding='valid', activation=tf.nn.leaky_relu)  # 60 -> 28
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(256, 3, 1, padding='valid', activation=tf.nn.leaky_relu)  # 28 -> 26
        self.bn3 = layers.BatchNormalization()
        self.conv4 = layers.Conv2D(512, 3, 2, padding='valid', activation=tf.nn.leaky_relu)  # 26 -> 12
        self.bn4 = layers.BatchNormalization()
        self.flatten = layers.Flatten()
        self.fullDense = layers.Dense(1)  # single raw logit, no activation
    def call(self, inputs, training=None, mask=None):
        """Score a batch of images as real or fake.

        Args:
            inputs: image batch, shape [b, 64, 64, 3].
            training: forwarded to the BatchNormalization layers.
            mask: unused; kept to match the Keras Model.call signature.

        Returns:
            Tensor of shape [b, 1] holding unnormalized logits.
        """
        feat = inputs
        stages = ((self.conv1, self.bn1),
                  (self.conv2, self.bn2),
                  (self.conv3, self.bn3),
                  (self.conv4, self.bn4))
        for conv, norm in stages:
            feat = norm(conv(feat), training=training)
        return self.fullDense(self.flatten(feat))