mnist数据集训练GAN模型(借鉴and学习)

学习GAN模型的简单代码实现——基于MNIST数据集

一、 环境介绍


本人使用的是tensorflow1.13(因为数据集比较小用cpu训练即可,大概30分钟左右),若想加速也可以使用GPU,环境配置tensorflow-gpu 1.13 + cuda10.0.13 + cudnn7.6.5 (建议使用conda配置,选择tensorflow-gpu版本后,会自动适配相应的cuda和cudnn)。


二、模型简介


采用最基础的全连接层网络学习使用GAN模型,没有添加卷积和池化过程,使用的激活函数为leaky ReLU以及tanh(双曲正切函数),训练过程采用了dropout,起到防止过拟合的作用。


三、代码介绍


首先看一下代码在IDE中的结构图。
在这里插入图片描述

放在前面:整个project,主要的文件是 train.py + show lines.py + show images.py + test.py

接下来顺着结构图逐一介绍:

  1. 首先 checkpoints 这是训练完成之后保存的计算图,模型中的变量对应的权重和偏置等经过训练的参数,也就是说在训练之前这个文件夹是没有的。checkpoints 的生成代码在 train.py 文件中,所以 checkpoints 文件夹并不需要你自己创建,我只是方便新手理解这个架构,故逐一介绍。

    checkpoints文件夹中各个文件介绍:
    checkpoint文件:包含最新的和所有的文件地址
    .data文件:包含训练变量的文件
    .index文件:描述variable中key和value的对应关系
    .meta文件:保存完整的网络图结构

  2. data文件夹,里面保存的是MNIST数据集,下载以及加载的代码也都在 train.py 文件中。

    由上到下介绍分别为测试集图像,测试集标签,训练集图像,训练集标签。

  3. losses.npy的生成代码在train.py文件中,目的是将训练过程中各个损失的值以矩阵的形式保存到本地,方便之后绘制损失曲线图。

  4. show images.py 显示在训练过程中采集的图像,每一个epoch采集一次,但并不是要全部输出,可以自己设定需要查看的epoch的效果,每个epoch保存了25个sample(epoch 和sample 的数值可以在 train.py 中修改)。

  5. show lines.py, 绘制损失曲线图。

  6. test.py ,加载训练好的生成器,生成图像。

  7. train.py ,GAN模型的网络结构以及训练过程。

  8. train_samples.pkl,用于保存在show images中显示的图像。

四、代码


1. train.py

import tensorflow as tf
import numpy as np
import pickle


from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('./data/')
img = mnist.train.images[50]  # img 此处为mnist的train图片数据集中的第51张图片,有784个参数


def get_inputs(real_size, noise_size):
    """
    真实图像tensor与噪声图像tensor
    """
    # placeholder占位符 [None, real_size] 表示行数不定,列数为real_size的值
    real_img = tf.placeholder(tf.float32, [None, real_size], name='real_img')
    noise_img = tf.placeholder(tf.float32, [None, noise_size], name='noise_img')

    return real_img, noise_img


def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    # 设置reuse=False时,函数get_variable()表示创建新变量,reuse=True时,表示调用已有的变量
    """
    生成器
    noise_img: 生成器的输入
    n_units: 隐层单元个数
    out_dim: 生成器输出tensor的size,这里应该为28*28=784
    alpha: leaky ReLU系数
    """
    # 创建generator变量空间,reuse = False
    with tf.variable_scope("generator", reuse=reuse):  # tf.variable_scope()用来指定变量的作用域,作为变量名的前缀

        # hidden layer
        hidden1 = tf.layers.dense(noise_img, n_units)  # dense :全连接层   inputs:输入该网络层的数据 units:输出的维度大小,改变inputs的最后一维
        # leaky ReLU    Leaky ReLU是给所有负值赋予一个非零斜率。
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # 返回二者较大值,因为alpha<1,所以正数时相当于返回自身,负数时返回alpha*自身,即Leaky ReLU过程
        # dropout
        hidden1 = tf.layers.dropout(hidden1, rate=0.2)  # dropout 防止过拟合,应用于输入,rate退出率   rate=0.2,输出单位将减少20%
        # logits & outputs
        logits = tf.layers.dense(hidden1, out_dim)  # 非归一化对数概率(又名logits) logits作为最后激活函数的前一层,输入到激活函数
        outputs = tf.tanh(logits, name='gen_output')  # 双曲正切函数,作用就是激活函数,值域为[-1,1]。类似于 sigmoid函数,不过sigmoid值域为[0,1]
        return logits, outputs


def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    """
    判别器

    n_units: 隐层结点数量
    alpha: Leaky ReLU系数
    """
    # 创建discriminator变量空间,reuse = False
    with tf.variable_scope("discriminator"
  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值