Tensorflow2 GAN 系列(三)——CGAN

本文还是基于手写数字生成介绍CGAN

需要先了解GAN的基本思想

Tensorflow2 GAN 系列(一)——基本GAN

 

CGAN 为 Condition GAN 的缩写,从名字就可以看出是条件生成对抗网络

与基本GAN不同的地方在于:

生成器输入为:噪声+条件

噪声就是一堆随机数,而对于手写数字生成,条件就是0-9的数字标签

判别器输入:结果或真实数据+条件

这个判别器不仅要判别是否为真实的图片,还要判别出这个图片是什么

这样当我们训练好生成器后,就可以根据传入的噪声和条件(标签),生成指定的内容

代码如下:

#条件GAN,生成和鉴别时加入条件输入,根据条件进行判断
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import numpy as np

(images,labels),_ = keras.datasets.mnist.load_data()


images = 2 * tf.cast(images,tf.float32) / 255. - 1
images = tf.expand_dims(images,-1)

dataset = tf.data.Dataset.from_tensor_slices((images,labels)).shuffle(images.shape[0]).batch(256)

noise_dim = 50

def generator_model():
    seed = layers.Input(shape=(noise_dim))#噪声
    label = layers.Input(shape=(()))#代表输入为数组

    x = layers.Embedding(10,50,input_length=1)(label)#输入长度为1,输入的种类为10,映射为50个长度的向量
    x = layers.concatenate([seed,x])#100维向量
    x = layers.Dense(3 * 3 * 128,activation='relu',use_bias=False)(x)
    x = layers.Reshape([3, 3, 128])(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2DTranspose(64,(3,3),strides=(2,2),use_bias=False)(x)#7*7*64
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2DTranspose(32, (3, 3), strides=(2, 2),padding='same', use_bias=False)(x)#same表示填充保持采样后尺寸不变,vaild表示不填充 14*14*32
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)
    #输出层一般不用bn
    x = layers.Activation(activation='tanh')(x)
    model = tf.keras.Model(inputs=[seed,label],outputs=x)
    return model

def discriminator_model():
    image = layers.Input(shape=(28,28,1))
    label = layers.Input(shape=(()))  # 代表输入为数组
    x = layers.Embedding(10, 28 * 28, input_length=1)(label)  # 输入长度为1,输入的种类为28*28,映射为50个长度的向量
    x = layers.Reshape([28, 28, 1])(x)
    x = layers.concatenate([x,image])  # 28*28*2

    x = layers.Conv2D(32,(3,3),strides=(2,2),padding='same',use_bias=False)(x)#14*14*32
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = layers.Dropout(0.5)(x)

    x = layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)#7*7*64
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = layers.Dropout(0.5)(x)

    x = layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same', use_bias=False)(x)  # 4*4*64
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = layers.Dropout(0.5)(x)

    x = layers.Flatten()(x)
    out = layers.Dense(1)(x)

    model = tf.keras.Model(inputs=[image,label],outputs=out)
    return model

disc = discriminator_model()
gen = generator_model()
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def disc_loss(real_out,fake_out):
    real_loss = bce(tf.ones_like(real_out),real_out)
    fake_loss = bce(tf.zeros_like(fake_out),fake_out)
    return real_loss + fake_loss

def gen_loss(fake_out):
    fake_loss = bce(tf.ones_like(fake_out),fake_out)
    return fake_loss

gen_opt = tf.keras.optimizers.Adam(1e-5)
dis_opt = tf.keras.optimizers.Adam(1e-5)

def train_step(images,labels):
    noise = tf.random.normal([labels.shape[0], noise_dim])
    with tf.GradientTape() as g_tape,tf.GradientTape() as d_tape:
        fake_img = gen((noise,labels),training=True)
        fakeout = disc((fake_img,labels),training=True)
        realout = disc((images,labels),training=True)
        d_loss = disc_loss(realout,fakeout)
        g_loss = gen_loss(fakeout)
    g_grad = g_tape.gradient(g_loss,gen.trainable_variables)
    gen_opt.apply_gradients(zip(g_grad,gen.trainable_variables))
    d_grad = d_tape.gradient(d_loss,disc.trainable_variables)
    dis_opt.apply_gradients(zip(d_grad,disc.trainable_variables))



def plot_gen_image(model,noise,label,epoch):
    gen_image = model((noise,label),training=False)
    fig = plt.figure(figsize=(10,1))
    for i in range(10):
        plt.subplot(1,10,i + 1)
        plt.imshow(tf.squeeze(gen_image[i] + 1) / 2.)
        plt.axis('off')
    plt.show()

noise = tf.random.normal([10,50])
label = tf.constant([0,1,2,3,4,5,6,7,8,9])



def main():
    for epoch in range(200):
        for images,labels in dataset:
            train_step(images,labels)
        print('Epoch:', epoch)
        if (epoch + 1) % 10 == 0:
            plot_gen_image(gen,noise,label,epoch)

if __name__ == '__main__':
    main()

 

 

  • 5
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值