生成式对抗网络GAN---生成mnist手写数字图像

生成式对抗网络GAN—生成mnist手写数字图像

一、GAN的基本结构

生成式对抗网络是一种无监督深度学习模型

主要组成:
生成模型G (generative model)
判别模型D (discriminative model)

网络的训练过程是这样的。首先,我们有一些随机噪声Z,喂入生成模型里面它会生成相应的数据,这个数据是假的。而相对于这个假的数据,我们有一些进行标注的或者是识别的真实数据。有了这个真数据和假数据之后,我们分别喂入判别网络模型,让他去学习里面的特征,数据类型,给出真假的一个判断。

举一个更通俗的例子来帮助大家理解。我们可以把生成模型想象成为一个伪造字画的造假的人。判别模型,可以想象成一个鉴定字画的鉴定师。开始的时候他们都还是新手,需要不断地学习,G的任务是生成以假乱真的字画,D的任务是从真假字画当中判别它的真假。
具体过程是这样的,首先我们给G一些材料也就是噪声,然后他会生成一幅字画,D会学习假字画特征和真字画的特征,然后去判定G生成这个到底是真的还是假的。开始的时候,G生成的这些字画是非常的拙劣,所以D一开始会判定为假。然后这个假的结果会返回到生成模型,他会在思考怎么才能生成更好,以假乱真的数据。如此这样的循环往复就是一个对抗的过程,最后,我们的目的就是,给生成模型寄一些随机的噪声,它就能生成以假乱真的一些数据骗过我们的判别模型。

对抗过程,用数学来表达,是一个二元极小极大值博弈,关于判别网络和生成网络的一个价值函数。

第一个是关于D的训练网络的一个函数,我们希望最大化log D(x),训练网络能够最大概率的分配这个训练模型的标签,也就是说我们希望鉴定师,他能够对真假数据的判定越来越准确。

第二个我们希望最小化log(1-D(G(Z)))。意思是,训练网络G生成能够欺骗网络D的数据,最大化D的损失。假设真数据为1,假数据为零。把G生成的数据拿去判别,结果应该为0,但是站在G的角度,希望能判别为1。

二、mnist代码实例

导入相关的包
tensorflow (谷歌开源的机器学习平台,有很多关于深度学习训练的函数)
numpy (用于数值计算)
matplotlib (画图用的)

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

读入mnist数据集的数据
将代码和数据集文件夹(mnistdata)放在同一目录下

mnist = input_data.read_data_sets('./mnistdata', one_hot=True)

随机噪声函数
采用均匀分布 np.randow.uniform(low, high,size)
返回值是一个从-1到1的噪声

def sample_z(m,n):
    return np.random.uniform(-1.,1.,size=[m,n])

生成模型输入和参数初始化

Z = tf.placeholder(tf.float32, shape=[None, 100])
G_W1 = tf.get_variable("G_W1", shape=[100,128],initializer=tf.contrib.layers.xavier_initializer())
G_b1= tf.Variable(tf.zeros(shape=[128]))
G_W2 = tf.get_variable("G_W2", shape=[128,784],initializer=tf.contrib.layers.xavier_initializer())
G_b2= tf.Variable(tf.zeros(shape=[784]))
theta_G = [G_W1, G_W2, G_b1, G_b2]

生成模型

def Gene(Z):
    G_h1 = tf.nn.relu(tf.matmul(Z, G_W1)+G_b1)#第一层矩阵相乘后激活
    G_log_prob = tf.matmul(G_h1, G_W2)+G_b2#第二层
    return tf.nn.sigmoid(G_log_prob)

判别模型输入和参数初始化
与生成模型基本相同,但有些数据需要对应,从784到1

X = tf.placeholder(tf.float32, shape=[None, 784])
D_W1 = tf.get_variable("D_W1", shape=[784,128],initializer=tf.contrib.layers.xavier_initializer())
D_b1= tf.Variable(tf.zeros(shape=[128]))
D_W2 = tf.get_variable("D_W2", shape=[128,1],initializer=tf.contrib.layers.xavier_initializer())
D_b2= tf.Variable(tf.zeros(shape=[1]))
theta_D = [D_W1, D_W2, D_b1, D_b2]

判别模型
比生成模型多返回第二层

def Disc(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1)+D_b1)
    D_logit = tf.matmul(D_h1, D_W2)+D_b2
    D_prob = tf.nn.sigmoid(D_logit)
    return D_prob, D_logit

画图

def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

喂入数据

G_sample = Gene(Z)
D_real, D_logit_real = Disc(X)
D_fake, D_logit_fake = Disc(G_sample)

计算G和D的损失(loss)均值
交叉熵(度量两个概率分布间的差异性信息),差异越大,交叉熵越大

D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

Adam算法优化器,学习率为0.0001

D_solver = tf.train.AdamOptimizer(0.0001).minimize(D_loss,var_list=theta_D)
G_solver = tf.train.AdamOptimizer(0.0001).minimize(G_loss, var_list=theta_G)

图像输出的位置

if not os.path.exists('out/'):
    os.makedirs('out/')

开始训练
一共迭代了1000000次,每1000次会生成一张图片

i=0
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    mb_size = 128
    Z_dim = 100
    for it in range(1000000):
        if it % 1000 == 0:
            samples = sess.run(G_sample, feed_dict={Z: sample_z(16, Z_dim)})

            fig = plot(samples)
            plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
            i += 1
            plt.close(fig)

        X_mb, _ = mnist.train.next_batch(mb_size)

        _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_z(mb_size, Z_dim)})
        _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_z(mb_size, Z_dim)})

        if it % 1000 == 0:
            print('Iter: {}'.format(it),'D loss: {}'.format(D_loss_curr),'G_loss: {}'.format(G_loss_curr))

三、生成结果

从左到右,从上到下依次是迭代了1000次,333000次,666000次,999000次后得到的图像

完整代码
https://github.com/SinsoledadFairy/mnist-gan

  • 5
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值