Tensorflow2 GAN 系列(二)——DCGAN

与基本GAN思想类似,只不过生成器和鉴别器中的所有层均改用卷积/转置卷积层实现

基本GAN:Tensorflow2 GAN 系列(一)——基本GAN

 

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt


# Load MNIST digits; the labels are unused for GAN training.
(train_images, train_labels), _ = keras.datasets.mnist.load_data()

# Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
train_images = 2 * tf.cast(train_images, tf.float32) / 255. - 1
# Add a trailing channel axis: (60000, 28, 28) -> (60000, 28, 28, 1).
train_images = tf.expand_dims(train_images, -1)

Batch_Size = 256
Buffer_Size = 60000  # shuffle buffer spans the whole dataset for a full shuffle

dataset = (tf.data.Dataset
           .from_tensor_slices(train_images)
           .shuffle(Buffer_Size)
           .batch(Batch_Size))

def generator_model():
    """Build the DCGAN generator: 100-dim noise -> 28x28x1 image in [-1, 1]."""
    return tf.keras.Sequential([
        # Project the latent vector to a flat 7*7*256 feature volume.
        layers.Dense(7 * 7 * 256, input_shape=(100,), use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),

        layers.Reshape([7, 7, 256]),

        # 7x7x256 -> 7x7x128 (stride 1 keeps the spatial size).
        layers.Conv2DTranspose(filters=128, kernel_size=(5, 5), strides=(1, 1),
                               padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),

        # 7x7x128 -> 14x14x64.
        layers.Conv2DTranspose(filters=64, kernel_size=(5, 5), strides=(2, 2),
                               padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),

        # 14x14x64 -> 28x28x1; tanh keeps pixels in [-1, 1].
        layers.Conv2DTranspose(filters=1, kernel_size=(5, 5), strides=(2, 2),
                               padding='same', use_bias=False, activation="tanh"),
    ])

def discriminator_model():
    """Build the DCGAN discriminator: 28x28x1 image -> single real/fake logit."""
    return tf.keras.Sequential([
        # 28x28x1 -> 14x14x64.
        layers.Conv2D(64, kernel_size=(5, 5), strides=(2, 2),
                      padding='same', input_shape=(28, 28, 1)),
        layers.LeakyReLU(),
        layers.Dropout(0.3),

        # 14x14x64 -> 7x7x128.
        layers.Conv2D(128, kernel_size=(5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),

        # 7x7x128 -> 4x4x256.
        layers.Conv2D(256, kernel_size=(5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),

        layers.Flatten(),
        # No activation: raw logit, paired with from_logits=True in the loss.
        layers.Dense(1),
    ])

# Shared binary cross-entropy; from_logits=True because the discriminator's
# final Dense(1) layer has no sigmoid activation.
cross_entropy=keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_out, fake_out):
    """Discriminator loss: push real logits toward 1 and fake logits toward 0."""
    loss_on_real = cross_entropy(tf.ones_like(real_out), real_out)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_out), fake_out)
    return loss_on_real + loss_on_fake

def generator_loss(fake_out):
    """Generator loss: reward fake logits the discriminator labels as real (1)."""
    return cross_entropy(tf.ones_like(fake_out), fake_out)

# Separate Adam optimizers so generator and discriminator update independently.
# NOTE(review): 1e-5 is far below the DCGAN paper's Adam(2e-4, beta_1=0.5);
# training will be very slow -- confirm this rate is intentional.
generator_opt=tf.keras.optimizers.Adam(0.00001)
discriminator_opt=tf.keras.optimizers.Adam(0.00001)

Epochs=100  # intended epoch count
input_dim=100  # latent-vector dimensionality
num_exp_to_generate=16  # samples per preview grid (4x4)
# Fixed noise so every epoch's preview shows the same 16 samples evolving.
seed=tf.random.normal([num_exp_to_generate,input_dim])

generator=generator_model()
discriminator=discriminator_model()

def train_step(images):
    """Run one adversarial update: one discriminator step and one generator step.

    Args:
        images: a batch of real images scaled to [-1, 1]
            (assumed shape (batch, 28, 28, 1) -- matches the pipeline above).
    """
    noise = tf.random.normal([Batch_Size, input_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        # Fix: pass training=True so BatchNormalization uses batch statistics
        # and Dropout is active. The original omitted it, so both networks
        # silently ran in inference mode during training.
        real_out = discriminator(images, training=True)
        gen_img = generator(noise, training=True)
        fake_out = discriminator(gen_img, training=True)
        dis_loss = discriminator_loss(real_out, fake_out)
        gen_loss = generator_loss(fake_out)

    # (Also fixes the 'gard' typos in the original local names.)
    gen_grad = gen_tape.gradient(gen_loss, generator.trainable_variables)
    dis_grad = dis_tape.gradient(dis_loss, discriminator.trainable_variables)
    discriminator_opt.apply_gradients(zip(dis_grad, discriminator.trainable_variables))
    generator_opt.apply_gradients(zip(gen_grad, generator.trainable_variables))

def genrate_plot_image(gen_model, test_noise):
    """Render a 4x4 grid of images generated from the fixed noise batch."""
    pre_images = gen_model(test_noise, training=False)
    fig = plt.figure(figsize=(4, 4))
    for idx, image in enumerate(pre_images):
        plt.subplot(4, 4, idx + 1)
        # Map tanh output [-1, 1] back to a pixel range before display.
        plt.imshow((image[:, :, 0] + 1) / 2 * 255.)
        plt.axis('off')
    plt.show()

def train(dataset, epochs):
    """Run the full training loop: `epochs` passes over `dataset`.

    After every epoch, print the epoch index as a progress indicator and show
    a grid of samples generated from the fixed `seed` noise.
    """
    for epoch in range(epochs):
        for batch in dataset:
            train_step(batch)
        print(epoch)
        genrate_plot_image(generator, seed)

if __name__ == '__main__':
    # Fix: use the module-level Epochs constant. The original passed a
    # hard-coded 200, silently disagreeing with (and dead-coding) Epochs=100.
    train(dataset, Epochs)

 

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值