[GAN] Generating Handwritten Digits from the MNIST Dataset with a GAN

I. Introduction to GANs

A GAN generally has two components: a generator and a discriminator.

The discriminator's goal: to tell, as reliably as possible, whether its input is fake data produced by the generator or real data.

The generator's goal: to fool the discriminator as often as possible, so that the discriminator judges the generated data to be real.

This adversarial game pushes both the generator and the discriminator to improve, until the generator can produce data that passes for the real thing.

The generator's input is a random vector; its output is data of the desired kind.

The discriminator's input is data; its output is a number between 0 and 1, the probability that the input is real.
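
Formally, the two networks play a minimax game. For reference (this is the objective from the original GAN paper by Goodfellow et al., 2014, stated here for context rather than taken from the code below):

\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

The discriminator loss and generator loss defined later in this post are exactly the two sides of this objective, implemented with binary cross-entropy.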

The code in this post is written against TensorFlow 2.0.0 and mainly uses Keras.

II. Code Walkthrough

1. Import the TensorFlow modules

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
from tensorflow.keras.layers import Dense,LeakyReLU,BatchNormalization,Reshape,Flatten
from tensorflow.keras.losses import BinaryCrossentropy
import os
import numpy as np
import matplotlib.pyplot as plt

The datasets package inside TensorFlow includes the MNIST handwritten digit dataset.

keras.layers provides layers that can be used directly.

keras.losses contains the loss functions.

2. Load and preprocess the data

(train_images,_),(_,_) = datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
train_images = (train_images-127.5)/127.5

BATCH_SIZE = 256
BUFFER_SIZE = 60000

train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

The pixel values are originally in [0, 255]; they are normalized to [-1, 1], and the reshape adds a channel dimension. Finally, a tf.data dataset is built from the images, shuffled, and batched.
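
As a quick optional sanity check (not part of the original training flow), you can verify the shape and value range after preprocessing:

print(train_images.shape)                       # (60000, 28, 28, 1)
print(train_images.min(), train_images.max())   # -1.0 1.0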

3. Define the generator model

def generator_model():
    model = keras.Sequential()

    model.add(Dense(256,input_shape=(100,),use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Dense(512,use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    # output layer: 784 = 28*28 pixels; tanh keeps the outputs in [-1,1],
    # so no BatchNormalization/LeakyReLU afterwards, which would break that range
    model.add(Dense(784,use_bias=False,activation='tanh'))

    model.add(Reshape((28,28,1)))

    return model

The generator's input is a random 100-dimensional vector.

The generator is built from three fully connected layers, the last of which is the output layer. Since the target is a 28x28 image, this layer has 784 neurons; its tanh activation keeps the generated values in [-1, 1], matching the normalized training data, and the result is reshaped into a 28x28x1 image.
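
To confirm the wiring, you can push a few random vectors through an untrained generator and inspect the output shape (a throwaway check, not part of the training code; gen and fake are names used only here):

gen = generator_model()
fake = gen(tf.random.normal([4,100]),training=False)
print(fake.shape)   # (4, 28, 28, 1)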

4. Define the discriminator model

def discriminator_model():
    model = keras.Sequential()

    model.add(Flatten())

    model.add(Dense(512,use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Dense(256,use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    # single output neuron: a raw logit scoring how real the input looks
    model.add(Dense(1))

    return model

The discriminator consists of one Flatten layer and three fully connected layers. The last layer has a single neuron so that the network produces one score per image; this score is a raw logit, which becomes the "probability of being real" once passed through a sigmoid.
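
Since there is no sigmoid on the last layer, the loss function in the next step is constructed with from_logits=True. If you ever need the actual probability, apply tf.sigmoid yourself, for example (disc, logit, and prob are illustrative names):

disc = discriminator_model()
logit = disc(tf.random.normal([1,28,28,1]),training=False)
prob = tf.sigmoid(logit)   # probability that the input is judged real
print(float(prob))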

5. Define the loss functions and optimizers

# the models output raw logits, so the loss applies the sigmoid internally
cross_entropy = BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_out,fake_out):
    # real images should be scored as 1, generated images as 0
    real_loss = cross_entropy(tf.ones_like(real_out),real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out)
    return real_loss + fake_loss

def generator_loss(fake_out):
    # the generator wants its images to be scored as 1 (real)
    return cross_entropy(tf.ones_like(fake_out),fake_out)

generator_opt = keras.optimizers.Adam(1e-4)
discriminator_opt = keras.optimizers.Adam(1e-4)

EPOCHS = 50
noise_dim = 100

num_exp_to_generate = 16

# fixed noise: the same 16 latent vectors are visualized after every epoch
seed = tf.random.normal([num_exp_to_generate,noise_dim])
generator = generator_model()
discriminator = discriminator_model()

Here real_out is the discriminator's output when it is fed real images, and fake_out is its output when it is fed generated (fake) images. Note that seed is a fixed batch of noise: reusing the same 16 latent vectors after every epoch makes the generated sample grids directly comparable across training.
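
To see how the two losses pull in opposite directions, here is a worked example with made-up logits (illustrative numbers only; positive logits mean "looks real"):

real_out = tf.constant([[2.0],[3.0]])    # discriminator is confident these are real
fake_out = tf.constant([[-2.0],[-1.0]])  # ...and confident these are fake
print(float(discriminator_loss(real_out,fake_out)))  # about 0.31: the discriminator is doing well
print(float(generator_loss(fake_out)))               # about 1.72: the generator is being caught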

6. Define the training step

def train_step(images):
    noise = tf.random.normal([BATCH_SIZE,noise_dim])
    # two tapes: each network's loss is differentiated w.r.t. its own weights
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        real_out = discriminator(images,training=True)
        gen_image = generator(noise,training=True)
        fake_out = discriminator(gen_image,training=True)
        gen_loss = generator_loss(fake_out)
        dis_loss = discriminator_loss(real_out,fake_out)
    gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables)
    gradient_dis = dis_tape.gradient(dis_loss,discriminator.trainable_variables)
    generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_dis,discriminator.trainable_variables))
    return gen_loss,dis_loss
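
Each call to train_step draws a fresh batch of noise, runs both networks forward under two gradient tapes, and then applies one optimizer step to each network. As written the step runs eagerly; if training feels slow, a standard TensorFlow 2 option is to trace it into a graph:

train_step = tf.function(train_step)   # optional speed-up; behavior is unchanged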

7. Define the plotting function

def generate_plot_image(gen_model,test_noise,epoch):
    pre_images = gen_model(test_noise,training=False)
    fig = plt.figure(figsize=(4,4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4,4,i+1)
        # map the pixel values from [-1,1] back to [0,1] for display
        plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray')
        plt.axis('off')
    os.makedirs('./images',exist_ok=True)  # make sure the output folder exists
    plt.savefig('./images/image_at_epoch_{:04d}.png'.format(epoch))
    plt.close(fig)

8. Start training

def train(dataset,epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            gen_loss,dis_loss = train_step(image_batch)
        # report the losses of the last batch in this epoch
        print('epoch {} finished'.format(epoch+1))
        print('gen_loss:',float(gen_loss),'dis_loss:',float(dis_loss))
        generate_plot_image(generator,seed,epoch)
    print('finished')

train(train_dataset,EPOCHS)
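
After training, you will probably want to keep the generator for sampling without retraining. A minimal sketch (the file name is just an example, and saving in HDF5 format requires the h5py package):

generator.save_weights('generator_weights.h5')   # example path

new_gen = generator_model()                      # rebuild the architecture...
new_gen.load_weights('generator_weights.h5')     # ...and restore the trained weights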

III. Training Results

After only a handful of epochs you can already faintly make out handwritten digits.

The result after 50 epochs of training is shown below.

IV. Full Code

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
from tensorflow.keras.layers import Dense,LeakyReLU,BatchNormalization,Reshape,Flatten
from tensorflow.keras.losses import BinaryCrossentropy
import os
import numpy as np
import matplotlib.pyplot as plt

(train_images,_),(_,_) = datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0],28,28,1).astype('float32')
train_images = (train_images-127.5)/127.5

BATCH_SIZE = 256
BUFFER_SIZE = 60000

train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

def generator_model():
    model = keras.Sequential()

    model.add(Dense(256,input_shape=(100,),use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Dense(512,use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    # output layer: 784 = 28*28 pixels; tanh keeps the outputs in [-1,1],
    # so no BatchNormalization/LeakyReLU afterwards, which would break that range
    model.add(Dense(784,use_bias=False,activation='tanh'))

    model.add(Reshape((28,28,1)))

    return model

def discriminator_model():
    model = keras.Sequential()

    model.add(Flatten())

    model.add(Dense(512,use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Dense(256,use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    # single output neuron: a raw logit scoring how real the input looks
    model.add(Dense(1))

    return model

# the models output raw logits, so the loss applies the sigmoid internally
cross_entropy = BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_out,fake_out):
    # real images should be scored as 1, generated images as 0
    real_loss = cross_entropy(tf.ones_like(real_out),real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out)
    return real_loss + fake_loss

def generator_loss(fake_out):
    # the generator wants its images to be scored as 1 (real)
    return cross_entropy(tf.ones_like(fake_out),fake_out)

generator_opt = keras.optimizers.Adam(1e-4)
discriminator_opt = keras.optimizers.Adam(1e-4)

EPOCHS = 50
noise_dim = 100

num_exp_to_generate = 16

# fixed noise: the same 16 latent vectors are visualized after every epoch
seed = tf.random.normal([num_exp_to_generate,noise_dim])
generator = generator_model()
discriminator = discriminator_model()

def train_step(images):
    noise = tf.random.normal([BATCH_SIZE,noise_dim])
    # two tapes: each network's loss is differentiated w.r.t. its own weights
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        real_out = discriminator(images,training=True)
        gen_image = generator(noise,training=True)
        fake_out = discriminator(gen_image,training=True)
        gen_loss = generator_loss(fake_out)
        dis_loss = discriminator_loss(real_out,fake_out)
    gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables)
    gradient_dis = dis_tape.gradient(dis_loss,discriminator.trainable_variables)
    generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_dis,discriminator.trainable_variables))
    return gen_loss,dis_loss

def generate_plot_image(gen_model,test_noise,epoch):
    pre_images = gen_model(test_noise,training=False)
    fig = plt.figure(figsize=(4,4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4,4,i+1)
        # map the pixel values from [-1,1] back to [0,1] for display
        plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray')
        plt.axis('off')
    os.makedirs('./images',exist_ok=True)  # make sure the output folder exists
    plt.savefig('./images/image_at_epoch_{:04d}.png'.format(epoch))
    plt.close(fig)

def train(dataset,epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            gen_loss,dis_loss = train_step(image_batch)
        # report the losses of the last batch in this epoch
        print('epoch {} finished'.format(epoch+1))
        print('gen_loss:',float(gen_loss),'dis_loss:',float(dis_loss))
        generate_plot_image(generator,seed,epoch)
    print('finished')

train(train_dataset,EPOCHS)

 
