PyTorch和TensorFlow生成对抗网络学习MNIST数据集

介绍

生成对抗网络(简称GAN)是最近开发的最受欢迎的机器学习算法之一。

对于人工智能(AI)领域的新手,我们可以简单地将机器学习(ML)描述为AI的子领域,它使用数据来“教”机器/程序如何执行新任务。 一个简单的例子就是使用一个人的脸部图像作为算法的输入,以便程序学会在任何给定的图片中识别同一个人(它可能也需要负样本)。 为此,我们可以将机器学习描述为应用数学优化,其中一种算法可以表示多维空间中的数据,然后学习区分新的多维矢量样本是否属于目标分布。

生成对抗网络的魔法

事实证明,他们在建模和生成高维数据方面非常成功,这就是为什么它们如此受欢迎。 然而,它们并不是生成模型的唯一类型,其他类型包括变分自动编码器(VAE),pixelCNN / pixelRNN和real NVP。 每个模型都有其自身的权衡。

一些与GAN最相关的利弊是:

  • 他们目前生成最清晰的图像
  • 它们易于训练(因为不需要统计推断),并且仅需要反向传播即可获得梯度
  • 由于不稳定的训练动态,GAN难以优化
  • 他们无法进行统计推断:GAN属于直接隐式密度模型。他们在没有明确定义概率分布函数的情况下对p(x)进行建模。

生成模型是了解当今围绕我们的大量数据的最有前途的方法之一。 根据OpenAI,能够创建数据的算法可能在本质上更好地理解世界。

生成模型可被认为比其鉴别器包含更多的信息,因为它们也可用于判别任务,例如分类或回归(目标是诸如ℝ的连续值)。 通过对联合概率分布函数进行统计推断,可以计算出此类任务大部分时间所需的条件概率分布函数 p ( y ∣ x ) p(y \mid x) p(yx)

尽管生成模型可用于分类和回归,但是与某些情况下的生成方法相比,完全鉴别方法通常在鉴别任务上更为成功。

案例

在几个用例中,生成模型可以应用于:

  • 生成逼真的艺术品样本(视频/图像/音频)
  • 使用时序数据进行仿真和计划
  • 统计推断
  • 也可用于生成可扩展小型数据集的输入

GAN概述

生成对抗网络由两个模型组成:

  • 第一个模型称为生成器,它旨在生成与预期相似的新数据。生成器可以与人类的赝品相提并论,后者可以伪造艺术品。
  • 第二种模型称为鉴别器。 该模型的目的是识别输入数据是由伪造者生成的“真实”(属于原始数据集)还是“伪造”(fake)。 在这种情况下,鉴别器类似于艺术专家,后者试图将艺术品视为真实或赝品。

GAN数学模型

训练GAN

由于使用神经网络对生成器和鉴别器进行建模,因此可以使用基于梯度的优化算法来训练GAN。 在我们的编码示例中,我们将使用随机梯度下降法,因为事实证明该梯度下降法在多个领域中均已成功完成。

训练GAN的基本步骤可以描述如下:

  1. 采样噪声集和实数集,每个集的大小为m
  2. 在此数据上训练鉴别器
  3. 采样大小为m的另一个噪声子集
  4. 在此数据上训练生成器
  5. 从第1步重复

编程GAN

首先导入必要库

pip install torchvision tensorboardx jupyter matplotlib numpy

导入以下依赖项

import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets

为了记录进度,我们将导入一个我创建的附加文件,使我们可以在控制台/ Jupyter中可视化训练过程,同时将其存储在TensorBoard中,以供那些已经知道如何使用它的人使用。

from utils import Logger

您需要下载文件并将其放在GAN文件所在的文件夹中。您不必了解此文件中的代码,因为它仅用于可视化目的。

数据集

我们将在这里使用MNIST数据集,它由大约60.000个手写数字的黑白图像组成,每个图像的尺寸为28x28像素。 该数据集将根据一些有用的技巧进行预处理,这些技巧被证明对训练GAN很有用。

具体来说,介于[0,255]之间的输入值将在-1和1之间归一化。这意味着值0将被映射为-1,值255被映射为1,并且类似地,介于两者之间的所有值都将得到a。 值在[-1,1]范围内。

网络

接下来,我们将从鉴别器开始定义神经网络。 该网络将以扁平化的图像作为输入,并返回其属于真实数据集或合成数据集的概率。 每个图像的输入大小将为28x28 = 784。 关于该网络的结构,它将具有三个隐藏层,每个隐藏层后面是Leaky-ReLU非线性和一个Dropout层,以防止过度拟合。将Sigmoid / Logistic函数应用于实值输出,以获取开放范围(0,1)中的值:

我们还需要一些其他功能,这些功能允许我们将扁平化的图像转换为二维表示,而另一种则相反。

另一方面,生成网络将潜变量向量作为输入,并返回784值向量,该向量对应于扁平化的28x28图像。 请记住,该网络的目的是学习如何创建手写数字的无法区别的图像,这就是为什么其输出本身就是新图像的原因。

该网络将具有三个隐藏层,每个隐藏层之后是Leaky-ReLU非线性。 输出层将具有TanH激活函数,该函数将结果值映射到(-1,1)范围内,该范围与我们预处理的MNIST图像所界定的范围相同。

我们还需要一些其他功能,以允许我们创建随机噪声。随机噪声将从此链接中提出的均值0和方差1的正态分布中采样。

def noise(size):
    '''
    Generates a 1-d vector of gaussian sampled random values
    '''
    n = Variable(torch.randn(size, 100))
    return n

优化

结果

最初生成的图像是纯噪声:

但是后来他们改进了,

在获得不错的合成图像之前,

也可以可视化学习过程。 正如您在下图中所看到的,开始时鉴别器错误非常高,因为它不知道如何正确地将图像分类为真实还是伪造。 当鉴别器变得更好并且其误差在步骤5k减小到约0.5时,生成器误差增加,证明了鉴别器的性能优于生成器,并且可以正确地对假样本进行分类。 随着时间的流逝和训练的继续,生成器误差会降低,这意味着生成的图像越来越好。 随着生成器的改进,鉴别器的误差也会增加,因为合成图像每次都变得越来越逼真。

生成器随时间的错误

鉴别器随时间的错误

我已经介绍了生成对抗网络。 我们首先了解它们是哪种算法,以及为什么它们如此重要。 接下来,我们探索了符合GAN的部分以及它们如何协同工作。 最终,我们通过编程并使用GAN的完全有效的实现进行编程,从而将理论与实践联系起来,该实现学会了创建MNIST数据集的综合示例。

本文源码

详情参阅 - 亚图跨际

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值