CNTK API文档翻译(20)——GAN处理MSIST数据基础

完成本期教程需要完成本系列的第四篇教程。

介绍

生成模型在深度学习的半监督或者非监督学习领域引起了广泛的专注,这些领域传统上都是使用判别模型的。生成模型的思想是线收集某个研究领域巨量的数据,然后训练得到一个可以生成这样的数据集的模型。这是一个需要大量训练和海量数据的热门研究领域。根据OpenAI博客的观点,这种方法可能可以用于进行计算机辅助艺术的创作,或者根据语言描述来对图片进行一些改变比如“让我的笑容更明媚”。这种方法目前已被用于图像去燥、图像修复、增加图像分辨率、图像结构识别,而且在增强学习、神经网络预训练这种标记数据代价高昂的领域,也有深入的研究。

生成模型能够产生与现实数据高度相似的内容(图像,声音等)是非常困难的。生成对抗网络(Generative Adversarial Network,GAN)是实现上诉描述的方法之一。一个来自LeCun summarizes的文章(地址:https://www.quora.com/What-are-some-recent-and-potentially-upcoming-breakthroughs-in-deep-learning)总结了GAN和GAN近十年的发展,我们在此展示如何使用CNTK来创建简单的GAN来生成模拟的MNIST数据。

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os

import cntk as C
import cntk.tests.test_utils

# (only needed for our build system)
cntk.tests.test_utils.set_device_from_pytest_env() 
# fix a random seed for CNTK components
C.cntk_py.set_fixed_random_seed(1) 

我们设定了两种运行模式:

  • 快速模式:isFast变量设置成True。这是我们的默认模式,在这个模式下我们会训练更少的次数,也会使用更少的数据,这个模式保证功能的正确性,但训练的结果还远远达不到可用的要求。
  • 慢速模式:我们建议学习者在学习的时候试试将isFast变量设置成False,这会让学习者更加了解本教程的内容。

  • 注意如果isFast被设为False,在有GPU的机器上代码将运行几个小时。你可以试试通过吧num_minibatches设置成一个较小的数字比如20000,减少循环次数,不过带来的代价就是生成图像质量的降低。
isFast = True

数据读取

GAN网络的输入将会是一个由随机数组成的向量。在训练结束是,GNA学会生成像MNIST数据集中一样的手写数字的图片。我们将使用与第四期下载的数据,一些关于数据格式的讨论和读取方法在之前的教程中有涉及到。在本教程中,只要知道下面的方法返回一个用来从MNIST数据集中生成图像的对象。因为我们是在创建一个非监督学习模型,我们只读取features,而不管labels。

# Ensure the training data is generated and available for this tutorial
# We search in two locations in the toolkit for the cached MNIST data set.

data_found = False
for data_dir in [os.path.join("..", "Examples", "Image", "DataSets", "MNIST"),
                 os.path.join("data", "MNIST")]:
    train_file = os.path.join(data_dir, "Train-28x28_cntk_text.txt")
    if os.path.isfile(train_file):
        data_found = True
        break

if not data_found:
    raise ValueError("Please generate the data by completing CNTK 103 Part A")

print("Data directory is {0}".format(data_dir))


def create_reader(path, is_training, input_dim, label_dim):
    deserializer = C.io.CTFDeserializer(
        filename = path,
        streams = C.io.StreamDefs(
            labels_unused = C.io.StreamDef(field = 'labels', shape = label_dim, is_sparse = False),
            features = C.io.StreamDef(field = 'features', shape = input_dim, is_sparse = False
            )
        )
    )
    return C.io.MinibatchSource(
        deserializers = deserializer,
        randomize = is_training,
        max_sweeps = C.io.INFINITELY_REPEAT if is_training else 1
    )

我们用于训练GAN的随机噪音使用noise_sample方法随机生成一些[-1,1]之间正态分布的噪音样本。

np.random.seed(123)
def noise_sample(num_samples):
    return np.random.uniform(
        low = -1.0,
        high = 1.0,
        size = [num_samples, g_input_dim]        
    ).astype(np.float32)

模型创建

GNA网络由两个子网络组成,一个叫生成器(Generator,G),一个叫判别器(Discriminator ,D)。

  • 生成器以随机噪音向量z为输入参数,努力生成与MNIST数据集中的真实图像($x$)相似的合成图像($x^*$)
  • 判别器努力区分真实图像($x$)和合成图像($x^*$)之间的区别。

image

在每轮训练中,生成器都会生成更加真实的合成图像(也就是减少合成图像和真是图像之间的差),同时判别器最大化给真实和生成的图像帖对真实或生成的标签的概率。GNA两个子网络中的冲突导致他收敛于一个平衡,此时生成器生成看起来很像MNIST图像的合成照片,判别器可以最多的随机猜测那个图片是真实的,哪个图片是合成的。训练的结果就是生成模型以随机的输入数字得到逼真的MNIST图像。

模型配置

首先,我们设置一些模型结构和训练超参数。

  • 生成网络是一个有一个隐藏层的全连接网络,输入数据是一个100维随机向量,输出数据会是一个784维的向量,对应28×28图像的扁平状态。判别器也是一个单层全连接网络,以生成器生成的784维向量或来自真实MNIST数据集的784维向量作为输入,输出一个代表输入数据是真实MNIST数据概率的标量。

模型构成

我们为我们的模型构建计算图,一个给生成器一个给判别器。首先我们我们创建一些模型结构参数。

  • 生成器输入100维随机向量($z$)输出一个784维的向量,对应28×28合成图像($x^*$)的扁平状态。在本教程中,我们简单将我们的生成器构造为两个全连接层。我们在最后一层使用tanh激活函数确保生成器函数的输出值在闭区间[-1,1]之间。因为之前也将MNIST图像映射到了这个范围内,所以这步操作是有必要的。
  • 判别器输入从生成器中输出的或者来自真实MNIST图像的784维向量($x^*$),输出输入图像是真实MNIST图像的概率。我们也使用两个全连接层构建判别器,最后一层使用sigmoid激活函数,以此保证判别器的输出值是一个有效的概率。
# architectural parameters
g_input_dim = 100
g_hidden_dim = 128
g_output_dim = d_input_dim = 784
d_hidden_dim = 128
d_output_dim = 1

def generator(z):
    with C.layers.default_options(init = C.xavier()):
        h1 = C.layers.Dense(g_hidden_dim, activation = C.relu)(z)
        return C.layers.Dense(g_output_dim, activation = C.tanh)(h1)

def discriminator(x):
    with C.layers.default_options(init = C.xavier()):
        h1 = C.layers.Dense(d_hidden_dim, activation = C.relu)(x)
        return C.layers.Dense(d_output_dim, activation = C.sigmoid)(h1)

我们使用的取样包数大小是1024,固定学习速率0.0005.如果使用快速模式我们只训练300轮以证明其功能正确性。

注意:在慢速模式,结果看起来会比快速模式好得多,不过根据你训练电脑的配置,你可能会登上几个小时到十几个小时不等。一般来说,取样包训练的越多,生成的图像越逼真。

# training config
minibatch_size = 1024
num_minibatches = 300 if isFast else 40000
lr = 0.00005

构建计算图

计算图的剩下部分主要用于协调训练算法和参数更新,这由于以下原因对GAN十分困难。

  • 第一,判别器必须既用于真实MNIST图像,也用于生成器函数生成的模拟图像。一种在计算图上记录上诉状态的方法是创建一个判别器函数输出的克隆副本,但是用不同的输入。在副本函数中设置method=share确保不同方式使用的判别器使用一样的参数。
  • 第二,我们需要对生成器和判别器使用不同的成本函数来更新模型参数。我们可以通过parameters属性获取计算图中函数对象的参数。然而,当更新模型参数时,更新只发生在两个子网络中的一个,另一个没有改变。换句话说,当更新生成器的参数时,我们只更新了G函数的参数,没有更新D函数的参数。

训练模型

训练GAN的代码与2014年神经信息处理系统大会(NIPS)上的一篇论文(链接:https://arxiv.org/pdf/1406.2661v1.pdf)提出的算法非常接近。在实现是,我们训练D来最大化给训练样本和G中生产的样本贴正确标签的概率。换句话说,D和G在玩一个双人针对函数 V(G,D) 极大极小值游戏。

minGmaxDV(D,G)=Ex[logD(x)]+Ez[log(1D(G(z)))]

这个游戏的最优点,生成器将生成非常逼真的数据,判别器预测合成图片的概率将会变成0.5。上面提到的论文中提到的算法会在下面的代码中实现。

image

ef build_graph(noise_shape, image_shape,
                G_progress_printer, D_progress_printer):
    input_dynamic_axes = [C.Axis.default_batch_axis()]
    Z = C.input_variable(noise_shape, dynamic_axes=input_dynamic_axes)
    X_real = C.input_variable(image_shape, dynamic_axes=input_dynamic_axes)
    X_real_scaled = 2*(X_real / 255.0) - 1.0

    # Create the model function for the generator and discriminator models
    X_fake = generator(Z)
    D_real = discriminator(X_real_scaled)
    D_fake = D_real.clone(
        method = 'share',
        substitutions = {X_real_scaled.output: X_fake.output}
    )

    # Create loss functions and configure optimazation algorithms
    G_loss = 1.0 - C.log(D_fake)
    D_loss = -(C.log(D_real) + C.log(1.0 - D_fake))

    G_learner = C.fsadagrad(
        parameters = X_fake.parameters,
        lr = C.learning_rate_schedule(lr, C.UnitType.sample),
        momentum = C.momentum_as_time_constant_schedule(700)
    )
    D_learner = C.fsadagrad(
        parameters = D_real.parameters,
        lr = C.learning_rate_schedule(lr, C.UnitType.sample),
        momentum = C.momentum_as_time_constant_schedule(700)
    )

    # Instantiate the trainers
    G_trainer = C.Trainer(
        X_fake,
        (G_loss, None),
        G_learner,
        G_progress_printer
    )
    D_trainer = C.Trainer(
        D_real,
        (D_loss, None),
        D_learner,
        D_progress_printer
    )

    return X_real, X_fake, Z, G_trainer, D_trainer

随着定义值函数,我们开始对GAN模型进行间接训练。训练这个模型根据硬件状况将会话费很长时间特别是如果你把isFast设为False。

def train(reader_train):
    k = 2

    # print out loss for each model for upto 50 times
    print_frequency_mbsize = num_minibatches // 50
    pp_G = C.logging.ProgressPrinter(print_frequency_mbsize)
    pp_D = C.logging.ProgressPrinter(print_frequency_mbsize * k)

    X_real, X_fake, Z, G_trainer, D_trainer = \
        build_graph(g_input_dim, d_input_dim, pp_G, pp_D)

    input_map = {X_real: reader_train.streams.features}
    for train_step in range(num_minibatches):

        # train the discriminator model for k steps
        for gen_train_step in range(k):
            Z_data = noise_sample(minibatch_size)
            X_data = reader_train.next_minibatch(minibatch_size, input_map)
            if X_data[X_real].num_samples == Z_data.shape[0]:
                batch_inputs = {X_real: X_data[X_real].data, 
                                Z: Z_data}
                D_trainer.train_minibatch(batch_inputs)

        # train the generator model for a single step
        Z_data = noise_sample(minibatch_size)
        batch_inputs = {Z: Z_data}
        G_trainer.train_minibatch(batch_inputs)

        G_trainer_loss = G_trainer.previous_minibatch_loss_average

    return Z, X_fake, G_trainer_loss


reader_train = create_reader(train_file, True, d_input_dim, label_dim=10)

G_input, G_output, G_trainer_loss = train(reader_train)

生成合成图片

现在我们训练好了这个模型,我们能通过简单的给生成器传入随机噪音来创造合成图片病展示他们。下面就生成的图片里的一些随机样本,要看其他照片,你只需要重新运行下面的代码。

def plot_images(images, subplot_shape):
    plt.style.use('ggplot')
    fig, axes = plt.subplots(*subplot_shape)
    for image, ax in zip(images, axes.flatten()):
        ax.imshow(image.reshape(28, 28), vmin = 0, vmax = 1.0, cmap = 'gray')
        ax.axis('off')
    plt.show()

noise = noise_sample(36)
images = G_output.eval({G_input: noise})
plot_images(images, subplot_shape =[6, 6])

image

大量的迭代会生成看起来更像MNIST数据集的图片。一个更好的效果展示如下。
image

注意:要获取真实世界的信号需要通过大量的迭代。即使MNIST是一个非常简单的数据,全连接网络在数据建模方面也非常有效。


欢迎扫码关注我的微信公众号获取最新文章
image

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值