基于Tensorflow的最基本GAN网络模型

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
import os
#(1)创建输入管道
# 导入原始数据
(train_images, train_labels),(_, _) = tf.keras.datasets.mnist.load_data()
# 查看原始数据大小与数据格式
# 60000张图片,每一张图片都是28*28像素
# print(train_images.shape)
# dtype('uint8'),每一位的范围都是0-255的整数,由于图像的一个通道中rgb颜色值就是0-255不等,因此uint8就是图像的标准数字格式
# print(train_images.dtype)

#(1.1)数据预处理
# 转换数据类型
train_images = train_images.reshape(train_images.shape[0], 28,28,1)
train_images = train_images.astype('float32')

# 归一化0-255>>[-1,1]
train_images = (train_images - 127.5)/127.5

#(1.2)确定训练时的BATCH_SIZE与BUFFER_SIZE
BATCH_SIZE = 256 # 每一个batch指一次训练,batch_size表示一次训练所需的数据个数。这里一次训练需要256张图片
BUFFER_SIZE = 60000 # 目前不知道buffer是干什么的

#(1.3)将归一化后的图像转化为tf内置的一种数据形式
datasets = tf.data.Dataset.from_tensor_slices(train_images)

#(1.4)将训练模型的数据集进行打乱的操作:shuffle
datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
#(2)生成器模型
def Generator_Model():
    model = keras.Sequential() # 顺序模型
    # dense 全连接层
    # 输入:长度为100的随机数向量(自己定义)
    # 输出:长度为256的向量
    model.add(layers.Dense(256, input_shape = (100,), use_bias = False))
    model.add(layers.BatchNormalization()) # 归一化层
    model.add(layers.LeakyReLU()) # 激活层
    
    model.add(layers.Dense(512, use_bias = False))
    model.add(layers.BatchNormalization()) # 归一化层
    model.add(layers.LeakyReLU()) # 激活层
    
    model.add(layers.Dense(28*28*1, use_bias = False, activation = 'tanh'))
    model.add(layers.BatchNormalization()) # 归一化层
    
    model.add(layers.Reshape((28,28,1))) # 写为元组的形式
    
    return model
#(3)判别器模型
def Discriminator_Model():
    model = keras.Sequential()
    
    model.add(layers.Flatten()) # 将3维图像拉伸为一维图像
    
    model.add(layers.Dense(512, use_bias = False))
    model.add(layers.BatchNormalization()) # 归一化层
    model.add(layers.LeakyReLU()) # 激活层
    
    model.add(layers.Dense(256, use_bias = False))
    model.add(layers.BatchNormalization()) # 归一化层
    model.add(layers.LeakyReLU()) # 激活层
    
    model.add(layers.Dense(1)) # 输出1或者0
    
    return model
    
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits = True)

#(4)判别器的损失函数:对于真是图片,判定为1;对于生成图片,判定为0
def discriminator_loss(real_out, fake_out):
    real_loss = cross_entropy(tf.ones_like(real_out),real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out)
    return real_loss+fake_loss

#(5)生成器损失函数:对于生成图片,判定为1
def generator_loss(fake_out):
    fake_loss = cross_entropy(tf.ones_like(fake_out),fake_out)
    return fake_loss
#(6)创建判别器和生成器的优化器,定义参数的学习速率
generator_opt = tf.keras.optimizers.Adam(1e-4)
discriminator_opt = tf.keras.optimizers.Adam(1e-4)
EPOCHS = 100
noise_dim = 100
num_exp_to_generate = 16
seed = tf.random.normal([num_exp_to_generate, noise_dim])

# 实例化生成器与判别器
Generator = Generator_Model()
Discriminator = Discriminator_Model()
#(7)训练GAN网络
# 每一个batch
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        real_output = Discriminator(images, training = True)
        gen_image = Generator(noise, training = True)
        fake_output = Discriminator(gen_image, training = True)
        
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
        
    #优化
    gradient_gen = gen_tape.gradient(gen_loss, Generator.trainable_variables)
    gradient_disc = disc_tape.gradient(disc_loss, Discriminator.trainable_variables)
    generator_opt.apply_gradients(zip(gradient_gen, Generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_disc, Discriminator.trainable_variables))
# 可视化函数
def generator_plt_img(gen_model, test_noise):

    pre_images = gen_model(test_noise, training = False)
    fig = plt.figure(figsize=(4, 4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4,4,i+1)
        plt.imshow((pre_images[i,:,:,0]+1)/2, cmap = 'gray')
        plt.axis('off')
    plt.show()
# 完整的训练模型的函数
def train(dataset, epochs):
    for epoch in range(epochs):
        for img_batch in dataset:
            train_step(img_batch)
            print('.',end='')
        generator_plt_img(Generator, seed)
# 训练模型
train(datasets, EPOCHS)

视频链接:https://www.bilibili.com/video/BV1f7411E7wU/?spm_id_from=333.999.0.0

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值