【tensorflow学习】最简单的GAN 实现

1.GAN基本思想

生成式对抗网络GAN (Generative adversarial networks) 是Goodfellow 等在2014 年提出的一种生成式模型。GAN 的核心思想来源于博弈论的纳什均衡。它设定参与游戏双方分别为一个生成器(Generator)和一个判别器(Discriminator), 生成器捕捉真实数据样本的潜在分布, 并生成新的数据样本; 判别器是一个二分类器, 判别输入是真实数据还是生成的样本。
为了取得游戏胜利, 这两个游戏参与者需要不断优化, 各自提高自己的生成能力和判别能力, 这个学习优化过程就是寻找二者之间的一个纳什均衡。生成器和判别器均可以采用目前研究火热的深度神经网络.

2.GAN结构

这里写图片描述

GAN的计算流程与结构如图 所示。
任意可微分的函数都可以用来表示GAN 的生成器和判别器, 由此,我们用可微分函数D 和G 来分别表示判别器和生成器, 它们的输入分别为真实数据x 和随机变量z。G(z) 为由G 生成的尽量服从真实数据分布 pdata p d a t a 的样本。如果判别器的输入来自真实数据, 标注为1.如果输入样本为G(z), 标注为0。
这里D 的目标是实现对数据来源的二分类判别: 真(来源于真实数据x 的分布) 或者伪(来源于生成器的伪数据G(z)),而G 的目标是使自己生成的伪数据G(z) 在D 上的表现D(G(z)) 和真实数据x 在D 上的表现D(x)一致.

3.GAN 的学习方法

这里写图片描述

实际上是生成器和判别器的极大极小博弈:
D的目标是最大化V(D,G)
G的目标是最小化 max V(D,G)

4.GAN 实现

接下来实现一个vanilla GAN,生成器和判别器都是二层的神经网络。数据集采用mnist。
点击下载代码
实现结果是这样的:
这里写图片描述

接下来就来一步一步实现吧~

4.1 导入数据

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from tensorflow.examples.tutorials.mnist import input_data

sess = tf.InteractiveSession()

mb_size = 128
Z_dim = 100

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

4.2 声明变量


def weight_var(shape, name):
    return tf.get_variable(name=name, shape=shape, initializer=tf.contrib.layers.xavier_initializer())


def bias_var(shape, name):
    return tf.get_variable(name=name, shape=shape, initializer=tf.constant_initializer(0))


# discriminater net

X = tf.placeholder(tf.float32, shape=[None, 784], name='X')

D_W1 = weight_var([784, 128], 'D_W1')
D_b1 = bias_var([128], 'D_b1')

D_W2 = weight_var([128, 1], 'D_W2')
D_b2 = bias_var([1], 'D_b2')


theta_D = [D_W1, D_W2, D_b1, D_b2]


# generator net

Z = tf.placeholder(tf.float32, shape=[None, 100], name='Z')

G_W1 = weight_var([100, 128], 'G_W1')
G_b1 = bias_var([128], 'G_B1')

G_W2 = weight_var([128, 784], 'G_W2')
G_b2 = bias_var([784], 'G_B2')

theta_G = [G_W1, G_W2, G_b1, G_b2]

4.3 定义模型


def generator(z):
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)

    return G_prob


def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_logit)
    return D_prob, D_logit

G_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)

4.4 设定loss function

论文中的解释如下:
这里写图片描述

由于tensorflow只能做minimize,loss function可以写成如下:

D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))

值得注意的是,论文中提到,比起最小化 tf.reduce_mean(1 - tf.log(D_fake)) ,最大化tf.reduce_mean(tf.log(D_fake)) 更好。

另外一种写法是利用tensorflow自带的tf.nn.sigmoid_cross_entropy_with_logits 函数:

D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

4.5 优化


D_optimizer = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_optimizer = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

4.6 训练

def sample_Z(m, n):
    '''Uniform prior for G(Z)'''
    return np.random.uniform(-1., 1., size=[m, n])

sess.run(tf.global_variables_initializer())
for it in range(1000000):
   X_mb, _ = mnist.train.next_batch(mb_size)

    _, D_loss_curr = sess.run([D_optimizer, D_loss], feed_dict={
                              X: X_mb, Z: sample_Z(mb_size, Z_dim)})
    _, G_loss_curr = sess.run([G_optimizer, G_loss], feed_dict={
                              Z: sample_Z(mb_size, Z_dim)})

    if it % 1000 == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'.format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

4.7 保存生成的图片

在4.6 训练中的代码加入一段,每1000 step保存16张生成图片:

def sample_Z(m, n):
    '''Uniform prior for G(Z)'''
    return np.random.uniform(-1., 1., size=[m, n])

def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):  # [i,samples[i]] imax=16
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig



if not os.path.exists('out/'):
    os.makedirs('out/')

sess.run(tf.global_variables_initializer())

i = 0
for it in range(1000000):
    if it % 1000 == 0:
        samples = sess.run(G_sample, feed_dict={
                           Z: sample_Z(16, Z_dim)})  # 16*784
        fig = plot(samples)
        plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)

    X_mb, _ = mnist.train.next_batch(mb_size)

    _, D_loss_curr = sess.run([D_optimizer, D_loss], feed_dict={
                              X: X_mb, Z: sample_Z(mb_size, Z_dim)})
    _, G_loss_curr = sess.run([G_optimizer, G_loss], feed_dict={
                              Z: sample_Z(mb_size, Z_dim)})

    if it % 1000 == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'.format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

5 reference

[1] http://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/
[2] Goodfellow, Ian, et al. “Generative adversarial nets.” Advances in Neural Information Processing Systems. 2014.
[3] 王坤峰, et al. “生成式对抗网络 GAN 的研究进展与展望.”

  • 12
    点赞
  • 100
    收藏
    觉得还不错? 一键收藏
  • 38
    评论
### 回答1: 使用TensorFlow训练并测试手写数字识别的MNIST数据集十分简单。首先,我们需要导入TensorFlow和MNIST数据集: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data 接下来,我们可以使用input_data.read_data_sets()函数加载MNIST数据集,其中参数为下载数据集的路径。我们可以将数据集分为训练集、验证集和测试集。这里我们将验证集作为模型的参数调整过程,测试集用于最终模型评估。 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 接下来,我们可以使用TensorFlow创建一个简单的深度学习模型。首先,我们创建一个输入占位符,用于输入样本和标签。由于MNIST数据集是28x28的图像,我们将其展平为一个784维的向量。 x = tf.placeholder(tf.float32, [None, 784]) y = tf.placeholder(tf.float32, [None, 10]) 接下来,我们可以定义一个简单的全连接神经网络,包含一个隐藏层和一个输出层。我们使用ReLU激活函数,并使用交叉熵作为损失函数。 hidden_layer = tf.layers.dense(x, 128, activation=tf.nn.relu) output_layer = tf.layers.dense(hidden_layer, 10, activation=None, name="output") cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output_layer, labels=y)) 然后,我们可以使用梯度下降优化器来最小化损失函数,并定义正确预测的准确率。这样就完成了模型的构建。 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) correct_prediction = tf.equal(tf.argmax(output_layer, 1), tf.argmax(y, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 接下来,我们可以在一个会话中运行模型。在每次迭代中,我们从训练集中随机选择一批样本进行训练。在验证集上进行模型的参数调整过程,最后在测试集上评估模型的准确率。 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for i in range(1000): batch_x, batch_y = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_x, y: batch_y}) val_accuracy = sess.run(accuracy, feed_dict={x: mnist.validation.images, y: mnist.validation.labels}) print("Validation Accuracy:", val_accuracy) test_accuracy = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}) print("Test Accuracy:", test_accuracy) 通过这个简单的代码,我们可以使用TensorFlow训练并测试MNIST数据集,并得到测试集上的准确率。 ### 回答2: gan tensorflow mnist是指使用TensorFlow框架训练生成对抗网络(GAN)来生成手写数字图像的任务。 首先,手写数字数据集是一个非常常见且经典的机器学习数据集。MNIST数据集包含了由0到9之间的手写数字的图像样本。在gan tensorflow mnist任务中,我们的目标是使用GAN来生成与这些手写数字样本类似的新图像。 GAN是一种由生成器和判别器组成的模型。生成器任务是生成看起来真实的图像,而判别器任务是判断给定图像是真实的(来自训练数据集)还是生成的(来自生成器)。这两个模型通过对抗训练来相互竞争和提高性能。 在gan tensorflow mnist任务中,我们首先需要准备和加载MNIST数据集。利用TensorFlow的函数和工具,我们可以轻松地加载和处理这些图像。 接下来,我们定义生成器和判别器模型。生成器模型通常由一系列的卷积、反卷积和激活函数层组成,以逐渐生成高质量的图像。判别器模型则类似于一个二分类器,它接收图像作为输入并输出真实或生成的预测结果。 我们使用TensorFlow的优化器和损失函数定义GAN模型的训练过程。生成器的目标是误导判别器,使其将生成的图像误认为是真实图像,从而最大限度地降低判别器的损失函数。判别器的目标是准确地区分真实和生成的图像,从而最大限度地降低自身的损失函数。 最后,我们使用训练数据集来训练GAN模型。通过多次迭代,生成器和判别器的性能会随着时间的推移而得到改善。一旦训练完成,我们可以使用生成器模型来生成新的手写数字图像。 总结来说,gan tensorflow mnist是指使用TensorFlow框架训练生成对抗网络来生成手写数字图像的任务。通过定义生成器和判别器模型,使用优化器和损失函数进行训练,我们可以生成类似于MNIST数据集手写数字的新图像。 ### 回答3: 用TensorFlow训练MNIST数据集可以实现手写数字的分类任务。首先我们需要导入相关库和模块,如tensorflow、keras以及MNIST数据集。接着,我们定义模型的网络结构,可以选择卷积神经网络(CNN)或者全连接神经网络(DNN)。对于MNIST数据集,我们可以选择使用CNN,因为它能更好地处理图像数据。 通过调用Keras中的Sequential模型来定义网络结构,可以添加多个层(如卷积层、池化层、全连接层等),用来提取特征和做出分类。其中,输入层的大小与MNIST图片的大小相对应,输出层的大小等于类别的数量(即0~9的数字)。同时,我们可以选择优化器(如Adam)、损失函数(如交叉熵)和评估指标(如准确率)。 接下来,我们用模型编译来配置模型的学习过程。在编译时,我们可以设置优化器、损失函数和评估指标。然后,我们用训练数据对模型进行拟合,通过迭代优化来调整模型的权重和偏置。迭代次数可以根据需要进行调整,以达到训练效果的需求。 训练结束后,我们可以使用测试数据对模型进行评估,获得模型在测试集上的准确率。最后,我们可以使用模型对新的未知数据进行预测,得到相应的分类结果。 综上所述,使用TensorFlow训练MNIST数据集可以实现手写数字的分类任务,通过定义模型结构、编译模型、拟合模型、评估模型和预测来完成整个过程。这个过程需要一定的编程知识和理解深度学习的原理,但TensorFlow提供了方便的api和文档,使我们能够相对容易地实现这个任务。
评论 38
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值