GAN(博弈)- 生成手写数据集

2.1 GAN的基本结构

GAN的主要结构包括一个生成器G(Generator)和一个判别器D(Discriminator)。

 

我们现在拥有大量的手写数字的数据集,我们希望通过GAN生成一些能够以假乱真的手写字图片。主要由如下两个部分组成:

  1. 定义一个模型来作为生成器(图三中蓝色部分Generator),能够输入一个向量,输出手写数字大小的像素图像。
  2. 定义一个分类器来作为判别器(图三中红色部分Discriminator)用来判别图片是真的还是假的(或者说是来自数据集中的还是生成器中生成的),输入为手写图片,输出为判别图片的标签。

2.2 GAN的训练方式

前面已经定义了一个生成器(Generator)来生成手写数字,一个判别器(Discrimnator)来判别手写数字是否是真实的,和一些真实的手写数字数据集。那么我们怎样来进行训练呢?

2.2.1 关于生成器

对于生成器,输入需要一个n维度向量,输出为图片像素大小的图片。因而首先我们需要得到输入的向量。

Tips: 这里的生成器可以是任意可以输出图片的模型,比如最简单的全连接神经网络,又或者是反卷积网络等。这里大家明白就好。

这里输入的向量我们将其视为携带输出的某些信息,比如说手写数字为数字几,手写的潦草程度等等。由于这里我们对于输出数字的具体信息不做要求,只要求其能够最大程度与真实手写数字相似(能骗过判别器)即可。所以我们使用随机生成的向量来作为输入即可,这里面的随机输入最好是满足常见分布比如均值分布,高斯分布等。

Tips: 假如我们后面需要获得具体的输出数字等信息的时候,我们可以对输入向量产生的输出进行分析,获取到哪些维度是用于控制数字编号等信息的即可以得到具体的输出。而在训练之前往往不会去规定它。

2.2.2 关于判别器

对于判别器不用多说,往往是常见的判别器,输入为图片,输出为图片的真伪标签。

Tips: 同理,判别器与生成器一样,可以是任意的判别器模型,比如全连接网络,或者是包含卷积的网络等等。

 

2.2.3 如何训练

上面进一步说明了生成器和判别器,接下来说明如何进行训练。

基本流程如下:



可以看到在(a)状态处于最初始的状态的时候,生成器生成的分布和真实分布区别较大,并且判别器判别出样本的概率不是很稳定,因此会先训练判别器来更好地分辨样本。
通过多次训练判别器来达到(b)样本状态,此时判别样本区分得非常显著和良好。然后再对生成器进行训练。
训练生成器之后达到(c)样本状态,此时生成器分布相比之前,逼近了真实样本分布。
经过多次反复训练迭代之后,最终希望能够达到(d)状态,生成样本分布拟合于真实样本分布,并且判别器分辨不出样本是生成的还是真实的(判别概率均为0.5)。也就是说我们这个时候就可以生成出非常真实的样本啦,目的达到。

 

import os
import pickle
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
#from tensorflow.examples.tutorials.mnist import input_data
from tensorflow_core.examples.tutorials.mnist import input_data
from datetime import datetime
 # 包含了很多数据集,第一次使用需要下载
# mnist = tf.keras.datasets.mnist
# (X_train, y_train), (X_test, y_test) = mnist.load_data()
# #X_train = X_train.reshape(28*28)
# print(len(X_train))


class MnistModel:


    def __init__(self):
        # mnist测试集
        self.mnist = input_data.read_data_sets('./mnist/raw', one_hot=True)
        #self.mnist = mnist.load_data()
        #self.X_train = X_train
        # 图片大小
        self.img_size = self.mnist.train.images[0].shape[0]
        #self.img_size = self.X_train[0].shape[0]
        # 每步训练使用图片数量
        self.batch_size = 64
        # 图片分块数量
        self.chunk_size = self.mnist.train.num_examples // self.batch_size
        #self.chunk_size = 60000 // self.batch_size
        # 训练循环次数
        self.epoch_size = 300
        # 抽取样本数
        self.sample_size = 25
        # 生成器判别器隐含层数量
        self.units_size = 128
        # 学习率
        self.learning_rate = 0.001
        # 平滑参数
        self.smooth = 0.1

    @staticmethod
    def generator_graph(fake_imgs, units_size, out_size, alpha=0.01):
        # 生成器与判别器属于两个网络 定义不同scope
        with tf.variable_scope('generator'):
            # 构建一个全连接层
            layer = tf.layers.dense(fake_imgs, units_size)
            # leaky ReLU 激活函数
            relu = tf.maximum(alpha * layer, layer)
            # dropout 防止过拟合
            drop = tf.layers.dropout(relu, rate=0.2)
            # logits
            # out_size应为真实图片size大小
            logits = tf.layers.dense(drop, out_size)
            # 激活函数 将向量值限定在某个区间 与 真实图片向量类似
            # 这里tanh的效果比sigmoid好一些
            # 输出范围(-1, 1) 采用sigmoid则为[0, 1]
            outputs = tf.tanh(logits)
            return logits, outputs

    @staticmethod
    def discriminator_graph(imgs, units_size, alpha=0.01, reuse=False):
        with tf.variable_scope('discriminator', reuse=reuse):
            # 构建全连接层
            layer = tf.layers.dense(imgs, units_size)
            # leaky ReLU 激活函数
            relu = tf.maximum(alpha * layer, layer)
            # logits
            # 判断图片真假 out_size直接限定为1
            logits = tf.layers.dense(relu, 1)
            # 激活函数
            outputs = tf.sigmoid(logits)
            return logits, outputs

    @staticmethod
    def loss_graph(real_logits, fake_logits, smooth):
        # 生成器图片loss
        # 生成器希望判别器判断出来的标签为1
        gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.ones_like(fake_logits) * (1 - smooth)))
        # 判别器识别生成器图片loss
        # 判别器希望识别出来的标签为0
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.zeros_like(fake_logits)))
        # 判别器识别真实图片loss
        # 判别器希望识别出来的标签为1
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits, labels=tf.ones_like(real_logits) * (1 - smooth)))
        # 判别器总loss
        dis_loss = tf.add(fake_loss, real_loss)
        return gen_loss, fake_loss, real_loss, dis_loss

    @staticmethod
    def optimizer_graph(gen_loss, dis_loss, learning_rate):
        # 所有定义变量
        train_vars = tf.trainable_variables()
        # 生成器变量
        gen_vars = [var for var in train_vars if var.name.startswith('generator')]
        # 判别器变量
        dis_vars = [var for var in train_vars if var.name.startswith('discriminator')]
        # optimizer
        # 生成器与判别器作为两个网络需要分别优化
        gen_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(gen_loss, var_list=gen_vars)
        dis_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(dis_loss, var_list=dis_vars)
        return gen_optimizer, dis_optimizer

    def train(self):
        # 真实图片与混淆图片
        # 不确定输入图片数量 用None
        real_imgs = tf.placeholder(tf.float32, [None, self.img_size], name='real_imgs')
        fake_imgs = tf.placeholder(tf.float32, [None, self.img_size], name='fake_imgs')

        # 生成器
        gen_logits, gen_outputs = self.generator_graph(fake_imgs, self.units_size, self.img_size)
        # 判别器对真实图片
        real_logits, real_outputs = self.discriminator_graph(real_imgs, self.units_size)
        # 判别器对生成器图片
        # 公用参数所以要reuse
        fake_logits, fake_outputs = self.discriminator_graph(gen_outputs, self.units_size, reuse=True)

        # 损失
        gen_loss, fake_loss, real_loss, dis_loss = self.loss_graph(real_logits, fake_logits, self.smooth)
        # 优化
        gen_optimizer, dis_optimizer = self.optimizer_graph(gen_loss, dis_loss, self.learning_rate)

        # 开始训练
        saver = tf.train.Saver()
        step = 0
        # 指定占用GPU比例
        # tensorflow默认占用全部GPU显存 防止在机器显存被其他程序占用过多时可能在启动时报错
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            for epoch in range(self.epoch_size):
                for _ in range(self.chunk_size):
                    batch_imgs, _ = self.mnist.train.next_batch(self.batch_size)
                    batch_imgs = batch_imgs * 2 - 1
                    # generator的输入噪声
                    noise_imgs = np.random.uniform(-1, 1, size=(self.batch_size, self.img_size))
                    # 优化
                    _ = sess.run(gen_optimizer, feed_dict={fake_imgs: noise_imgs})
                    _ = sess.run(dis_optimizer, feed_dict={real_imgs: batch_imgs, fake_imgs: noise_imgs})
                    step += 1
                # 每一轮结束计算loss
                # 判别器损失
                loss_dis = sess.run(dis_loss, feed_dict={real_imgs: batch_imgs, fake_imgs: noise_imgs})
                # 判别器对真实图片
                loss_real = sess.run(real_loss, feed_dict={real_imgs: batch_imgs, fake_imgs: noise_imgs})
                # 判别器对生成器图片
                loss_fake = sess.run(fake_loss, feed_dict={real_imgs: batch_imgs, fake_imgs: noise_imgs})
                # 生成器损失
                loss_gen = sess.run(gen_loss, feed_dict={fake_imgs: noise_imgs})

                print(datetime.now().strftime('%c'), ' epoch:', epoch, ' step:', step, ' loss_dis:', loss_dis,
                      ' loss_real:', loss_real, ' loss_fake:', loss_fake, ' loss_gen:', loss_gen)
            model_path = os.getcwd() + os.sep + "mnist.model"
            saver.save(sess, model_path, global_step=step)

    def gen(self):
        # 生成图片
        sample_imgs = tf.placeholder(tf.float32, [None, self.img_size], name='sample_imgs')
        gen_logits, gen_outputs = self.generator_graph(sample_imgs, self.units_size, self.img_size)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.restore(sess, tf.train.latest_checkpoint('.'))
            sample_noise = np.random.uniform(-1, 1, size=(self.sample_size, self.img_size))
            samples = sess.run(gen_outputs, feed_dict={sample_imgs: sample_noise})
        with open('samples.pkl', 'wb') as f:
            pickle.dump(samples, f)

    @staticmethod
    def show():
        # 展示图片
        with open('samples.pkl', 'rb') as f:
            samples = pickle.load(f)
        fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5, sharey=True, sharex=True)
        for ax, img in zip(axes.flatten(), samples):
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        plt.show()

 

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
GAN生成对抗网络)是一种能够生成逼真图像的深度学习模型。想要利用GAN生成自定义数据集,需要遵循以下步骤: 1. 收集和准备数据集 首先要收集和准备用于GAN生成数据集数据集可以来自任何来源,如摄像头、图像库、网络。关键是确保数据集数量足够多,且具有明确的主题和风格。 2. 设计GAN网络架构 设计GAN网络架构可以使用开源GAN库进行训练。如果没有深度学习背景,可以采用现成的GAN网络架构,如DCGAN(深度卷积GAN)。 3. 训练GAN网络 将设计好的GAN网络架构与数据集进行训练。在训练过程中需要特别注意超参数,包括学习率、批处理大小、优化算法等,可以使用交叉验证的方式进行优化。 4. 生成数据集 GAN在经过训练之后,就可以用来生成自定义的新数据集。可以手动调整生成图像的数量、风格和主题,以便生成需要的数据集。 5. 数据集标记 生成数据集需要进行标记,以便进行进一步的模型训练或评估。可以按照需要进行图像分类、目标检测、分割等标记,这样可以提高数据的质量,为后续任务的完成提供基础。 总的来说,GAN生成自定义数据集需要经过数据收集、GAN网络架构设计、训练GAN网络、生成数据集数据集标记等多个环节。只有经过系统性地处理和整合,才能生成质量较高的自定义数据集,为后续深度学习模型的训练和评估工作提供基础。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值