生成式对抗网络GAN生成手写数字

版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/yql_617540298/article/details/81913212

GAN(Generative Adversarial Networks)是较为火热的一种神经网络,具有较多的优势和特点。

一、GAN

1. 原理

源自于零和博弈(zero-sum game),包括生成模型(generative model, G)和判别模型(discriminative model, D)。

G,D的主要功能是:

(1)G是一个生成式的网络,它接收一个随机的噪声z(随机数),通过这个噪声生成图像;

(2)D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量辨别出G生成的假图像和真实的图像。这样,G和D构成了一个动态的“博弈过程”,最终的平衡点即纳什均衡点。

2. 特点

(1)相比较传统的模型,他存在两个不同的网络,而不是单一的网络,并且训练方式采用的是对抗训练方式;

(2)GAN中G的梯度更新信息来自判别器D,而不是来自数据样本。

3. 优点

(1)GAN采用的是一种无监督的学习方式训练,可以被广泛用在无监督学习和半监督学习领域

(2)相比VAE,GAN没有变分下界,如果判别器训练良好,那么生成器可以完美的学习到训练样本的分布。换句话说,GANs是渐进一致的,但是VAE是有偏差的;

(3)GAN应用到一些场景上,比如图片风格迁移,超分辨率,图像补全,去噪,避免了损失函数设计的困难,只要有一个基准,直接上判别器,剩下的就交给对抗训练了。

二、MNIST数据集

下载地址:https://download.csdn.net/download/yql_617540298/10618317

  • Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
  • Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
  • Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
  • Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

三、GAN生成MNIST手写数字

采用Tensorflow框架,生成手写数字。

tensorflow安装:

pip install tensorflow-gpu==版本号

检查tensorflow版本:

代码如下:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

class G_Net:

    def forward(self, x, reuse=False):
        with tf.variable_scope("gnet", reuse=reuse):
            
            full_connect= tf.layers.dense(x, 4*4*512)
            flatten = tf.reshape(full_connect, [-1, 4, 4, 512])

            layer_1 = tf.nn.leaky_relu(tf.layers.batch_normalization(flatten, training=True))
            layer_1_dout = tf.nn.dropout(layer_1, keep_prob=0.8)

            # 4 * 4 * 512 to 7 x 7 x 256
            layer_2 = tf.nn.leaky_relu(tf.layers.batch_normalization(tf.layers.conv2d_transpose(layer_1_dout, 256, 4, strides=1, padding='valid'), training=True))
            layer_2_dout = tf.nn.dropout(layer_2, keep_prob=0.8)

            # 7 x 7 256 to 14 x 14 x 128
            layer_3 = tf.nn.leaky_relu(tf.layers.batch_normalization(tf.layers.conv2d_transpose(layer_2_dout, 128, 3, strides=2, padding='same'), training=True))
            layer_3_dout = tf.nn.dropout(layer_3, keep_prob=0.8)

            # 14 x 14 x 128 to 28 x 28 x 1
            f_img = tf.layers.conv2d_transpose(layer_3_dout, 1, 3, strides=2, padding='same')
            
            out = tf.tanh(f_img)
            return out

    def getParam(self):
        return tf.get_collection(tf.GraphKeys.VARIABLES, scope="gnet")

class D_Net:

    def forward(self, x, reuse=False):
        with tf.variable_scope("dnet", reuse=reuse):
            
            # 28 x 28 x 1 to 14 x 14 x 128
            layer_1 = tf.nn.leaky_relu(tf.layers.conv2d(x, 128, kernel_size=3, strides=2, padding="same"))
            layer_1_dout = tf.nn.dropout(layer_1, keep_prob=0.8)

            # # 14 x 14 x 128 to 7 x 7 x 256
            layer_2 = tf.nn.leaky_relu(tf.layers.batch_normalization(tf.layers.conv2d(layer_1_dout, 256, 3, strides=2, padding="same"), training=True))
            layer_2_dout = tf.nn.dropout(layer_2, keep_prob=0.8)

            # 7 x 7 x 256 to 4 x 4 x 512
            layer_3 = tf.nn.leaky_relu(tf.layers.batch_normalization(tf.layers.conv2d(layer_2_dout, 512, 3, strides=2, padding="same"), training=True))
            layer_3_dout = tf.nn.dropout(layer_3, keep_prob=0.8)

            # 4 x 4 x 512 to 4 * 4* 512 x 1
            flatten = tf.reshape(layer_3_dout, (-1, 4*4*512))
            # logits = tf.sigmoid(tf.layers.dense(flatten, 1))
            logits = tf.layers.dense(flatten, 1)
            return logits

    def getParam(self):
        return tf.get_collection(tf.GraphKeys.VARIABLES, scope="dnet")

class GAN_NET:

    def __init__(self):
        self.f_xs = tf.placeholder(dtype=tf.float32, shape=[None, 100])
        self.f_ys = tf.placeholder(dtype=tf.float32, shape=[None, 1])

        self.t_xs = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1])
        self.t_ys = tf.placeholder(dtype=tf.float32, shape=[None ,1])

        self.gnet = G_Net()
        self.dnet = D_Net()

        self.forward()
        self.backward()

    def forward(self):
        
        self.g_out = self.gnet.forward(self.f_xs)
        self.g_d_out = self.dnet.forward(self.g_out)
        self.t_d_out = self.dnet.forward(self.t_xs, True)

    def backward(self):
        # self.d_loss = tf.reduce_mean((self.t_d_out - self.t_ys) ** 2)
        self.d_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.g_d_out, labels=self.f_ys)
                                     + tf.nn.sigmoid_cross_entropy_with_logits(logits=self.t_d_out, labels=self.t_ys))
        self.d_opt = tf.train.AdamOptimizer().minimize(self.d_loss, var_list=self.dnet.getParam())

        # self.g_loss = tf.reduce_mean((self.g_d_out - self.f_ys) ** 2)
        self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.g_d_out, labels=self.f_ys))
        self.g_opt = tf.train.AdamOptimizer().minimize(self.g_loss, var_list=self.gnet.getParam())

if __name__ == '__main__':

    gan_net = GAN_NET()

    save = tf.train.Saver(max_to_keep=1)

    d_batch = 10
    g_batch = 80

    plt.ion()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for i in range(100000000):
            xs, _ = mnist.train.next_batch(d_batch)
            t_xs = np.reshape(xs, newshape=(d_batch,28,28,-1))
            t_ys = np.ones(shape=[d_batch,1])

            f_xs = np.random.normal(-1,1,size=(d_batch,100))
            f_ys = np.zeros(shape=[d_batch, 1])

            d_loss, _ = sess.run([gan_net.d_loss, gan_net.d_opt], feed_dict={
                gan_net.f_xs:f_xs, gan_net.f_ys:f_ys,
                gan_net.t_xs:t_xs, gan_net.t_ys:t_ys
            })


            _f_xs = np.random.normal(-1,1,size=(g_batch,100))
            _t_ys = np.ones(shape=[g_batch, 1])

            g_loss,_ = sess.run([gan_net.g_loss,gan_net.g_opt], feed_dict={gan_net.f_xs:_f_xs,gan_net.f_ys:_t_ys})

            print("i---",i," d_loss=",d_loss,"-- g_loss",g_loss)

            if (i+1) % 100 == 0:
                save_path = save.save(sess, "./save/gan_mnist")
                print(save_path)

            if (i+1) % 5 == 0:
                # save.restore(sess, "./save/gan_mnist")
                t_data = np.random.normal(-1,1,size=(1,100))
                img = sess.run(gan_net.g_out, feed_dict={gan_net.f_xs:t_data})
                img = (img + 1) * 127.5
                img = np.array(img, np.uint8)
                img = np.reshape(img[0],newshape=[28,28])
                plt.imshow(img)
                plt.pause(1)

训练:

展开阅读全文

没有更多推荐了,返回首页