基于Tensorflow2的GAN生成对抗网络(一)

GAN的网络需要一个判别器(Discriminator)和一个生成器(Generator),判别器需要不停判别生成器产生的图片是否为真实图片,直到最后判别不出来生成器生成的图片到底是真是假,则达到纳什均衡点。

本例中,生成器的任务是将100个大小的一维随机数组,通过Conv2DTranspose变成【batchSize,64,64,3】大小的图片。这些生成的图片要最大程度地生成得和真的一样,不让判别器识别出来。

判别器的任务是接受到生成器的假图片,并判断它是假的,接受到真的图片,判断它是真的。

import tensorflow as tf
from tensorflow.keras import Model,layers

#涉及到Kernel Size 和 Strides的调节问题,可以使用公式:N为输入数据的维度,S为步长,P为单侧PADDING个数,F为卷积核大小

#Conv2DTranspose: (N-1)*S - 2P + F = 输出数据的维度

#Conv2D: floor((N - F + 2P)/S) + 1 = 输出数据的维度

class Generator(Model):
    """DCGAN-style generator.

    Maps a latent batch [b, 100] to an image batch [b, 64, 64, 3] by
    projecting to a 3x3x512 feature volume and upsampling through five
    transposed convolutions.

    With 'valid' padding the spatial size follows (N - 1) * S + F:
    3 -> 9 -> 29 -> 59 -> 60 -> 64.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Dense projection of the latent vector, reshaped to [b, 3, 3, 512]
        # inside call().
        self.fullDense = layers.Dense(3 * 3 * 512, activation=tf.nn.leaky_relu)

        # Upsampling stack: transposed conv followed by batch norm.
        # (3 - 1) * 3 + 3 = 9
        self.conv1 = layers.Conv2DTranspose(512, 3, 3, padding='valid', activation=tf.nn.leaky_relu)
        self.bn1 = layers.BatchNormalization()

        # (9 - 1) * 3 + 5 = 29
        self.conv2 = layers.Conv2DTranspose(256, 5, 3, padding='valid', activation=tf.nn.leaky_relu)
        self.bn2 = layers.BatchNormalization()

        # (29 - 1) * 2 + 3 = 59
        self.conv3 = layers.Conv2DTranspose(128, 3, 2, padding='valid', activation=tf.nn.leaky_relu)
        self.bn3 = layers.BatchNormalization()

        # (59 - 1) * 1 + 2 = 60
        self.conv4 = layers.Conv2DTranspose(64, 2, 1, padding='valid', activation=tf.nn.leaky_relu)
        self.bn4 = layers.BatchNormalization()

        # Output head: (60 - 1) * 1 + 5 = 64; tanh keeps pixels in [-1, 1].
        self.conv5 = layers.Conv2DTranspose(3, 5, 1, padding='valid', activation=tf.nn.tanh)

    def call(self, inputs, training=None, mask=None):
        """Generate images from latent vectors.

        Args:
            inputs: latent batch — presumably shape [b, 100]; any flat
                size that the Dense layer accepts works.
            training: forwarded to the BatchNormalization layers.
            mask: unused, kept for the Keras Model interface.

        Returns:
            Tensor of shape [b, 64, 64, 3] with values in [-1, 1].
        """
        x = tf.reshape(self.fullDense(inputs), [-1, 3, 3, 512])
        # Run each (transposed conv, batch norm) pair in order.
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2),
                  (self.conv3, self.bn3), (self.conv4, self.bn4))
        for conv, bn in stages:
            x = bn(conv(x), training=training)
        return self.conv5(x)



class Discriminator(Model):
    """DCGAN-style discriminator.

    Maps an image batch [b, 64, 64, 3] to one real/fake logit per
    sample, shape [b, 1].

    With 'valid' padding the spatial size follows floor((N - F) / S) + 1:
    64 -> 60 -> 28 -> 26 -> 12.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Downsampling stack: conv followed by batch norm.
        # (64 - 5) / 1 + 1 = 60
        self.conv1 = layers.Conv2D(64, 5, 1, padding='valid', activation=tf.nn.leaky_relu)
        self.bn1 = layers.BatchNormalization()

        # floor((60 - 5) / 2) + 1 = 28
        self.conv2 = layers.Conv2D(128, 5, 2, padding='valid', activation=tf.nn.leaky_relu)
        self.bn2 = layers.BatchNormalization()

        # (28 - 3) / 1 + 1 = 26
        self.conv3 = layers.Conv2D(256, 3, 1, padding='valid', activation=tf.nn.leaky_relu)
        self.bn3 = layers.BatchNormalization()

        # floor((26 - 3) / 2) + 1 = 12
        self.conv4 = layers.Conv2D(512, 3, 2, padding='valid', activation=tf.nn.leaky_relu)
        self.bn4 = layers.BatchNormalization()

        # Classification head: flatten features, emit a single raw logit.
        self.flatten = layers.Flatten()
        self.fullDense = layers.Dense(1)

    def call(self, inputs, training=None, mask=None):
        """Score images as real or fake.

        Args:
            inputs: image batch — presumably shape [b, 64, 64, 3].
            training: forwarded to the BatchNormalization layers.
            mask: unused, kept for the Keras Model interface.

        Returns:
            Raw (unactivated) logits of shape [b, 1].
        """
        x = inputs
        # Run each (conv, batch norm) pair in order.
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2),
                  (self.conv3, self.bn3), (self.conv4, self.bn4))
        for conv, bn in stages:
            x = bn(conv(x), training=training)
        return self.fullDense(self.flatten(x))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值