介绍
生成对抗网络(简称GAN)是最近开发的最受欢迎的机器学习算法之一。
对于人工智能(AI)领域的新手,我们可以简单地将机器学习(ML)描述为AI的子领域,它使用数据来“教”机器/程序如何执行新任务。 一个简单的例子就是使用一个人的脸部图像作为算法的输入,以便程序学会在任何给定的图片中识别同一个人(它可能也需要负样本)。 为此,我们可以将机器学习描述为应用数学优化,其中一种算法可以表示多维空间中的数据,然后学习区分新的多维矢量样本是否属于目标分布。
生成对抗网络的魔法
事实证明,他们在建模和生成高维数据方面非常成功,这就是为什么它们如此受欢迎。 然而,它们并不是生成模型的唯一类型,其他类型包括变分自动编码器(VAE),pixelCNN / pixelRNN和real NVP。 每个模型都有其自身的权衡。
一些与GAN最相关的利弊是:
- 他们目前生成最清晰的图像
- 它们易于训练(因为不需要统计推断),并且仅需要反向传播即可获得梯度
- 由于不稳定的训练动态,GAN难以优化
- 他们无法进行统计推断:GAN属于直接隐式密度模型。他们在没有明确定义概率分布函数的情况下对p(x)进行建模。
生成模型是了解当今围绕我们的大量数据的最有前途的方法之一。 根据OpenAI,能够创建数据的算法可能在本质上更好地理解世界。
生成模型可被认为比其鉴别器包含更多的信息,因为它们也可用于判别任务,例如分类或回归(目标是诸如ℝ的连续值)。 通过对联合概率分布函数进行统计推断,可以计算出此类任务大部分时间所需的条件概率分布函数 p ( y ∣ x ) p(y \mid x) p(y∣x)。
尽管生成模型可用于分类和回归,但是与某些情况下的生成方法相比,完全鉴别方法通常在鉴别任务上更为成功。
案例
在几个用例中,生成模型可以应用于:
- 生成逼真的艺术品样本(视频/图像/音频)
- 使用时序数据进行仿真和计划
- 统计推断
- 也可用于生成可扩展小型数据集的输入
GAN概述
生成对抗网络由两个模型组成:
- 第一个模型称为生成器,它旨在生成与预期相似的新数据。生成器可以与人类的赝品相提并论,后者可以伪造艺术品。
- 第二种模型称为鉴别器。 该模型的目的是识别输入数据是由伪造者生成的“真实”(属于原始数据集)还是“伪造”(fake)。 在这种情况下,鉴别器类似于艺术专家,后者试图将艺术品视为真实或赝品。
GAN数学模型
训练GAN
由于使用神经网络对生成器和鉴别器进行建模,因此可以使用基于梯度的优化算法来训练GAN。 在我们的编码示例中,我们将使用随机梯度下降法,因为事实证明该梯度下降法在多个领域中均已成功完成。
训练GAN的基本步骤可以描述如下:
- 采样噪声集和实数集,每个集的大小为m
- 在此数据上训练鉴别器
- 采样大小为m的另一个噪声子集
- 在此数据上训练生成器
- 从第1步重复
编程GAN
首先导入必要库
pip install torchvision tensorboardx jupyter matplotlib numpy
导入以下依赖项
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
为了记录进度,我们将导入一个我创建的附加文件,使我们可以在控制台/ Jupyter中可视化训练过程,同时将其存储在TensorBoard中,以供那些已经知道如何使用它的人使用。
from utils import Logger
您需要下载文件并将其放在GAN文件所在的文件夹中。您不必了解此文件中的代码,因为它仅用于可视化目的。
数据集
我们将在这里使用MNIST数据集,它由大约60.000个手写数字的黑白图像组成,每个图像的尺寸为28x28像素。 该数据集将根据一些有用的技巧进行预处理,这些技巧被证明对训练GAN很有用。
具体来说,介于[0,255]之间的输入值将在-1和1之间归一化。这意味着值0将被映射为-1,值255被映射为1,并且类似地,介于两者之间的所有值都将得到a。 值在[-1,1]范围内。
网络
接下来,我们将从鉴别器开始定义神经网络。 该网络将以扁平化的图像作为输入,并返回其属于真实数据集或合成数据集的概率。 每个图像的输入大小将为28x28 = 784。 关于该网络的结构,它将具有三个隐藏层,每个隐藏层后面是Leaky-ReLU非线性和一个Dropout层,以防止过度拟合。将Sigmoid / Logistic函数应用于实值输出,以获取开放范围(0,1)中的值:
我们还需要一些其他功能,这些功能允许我们将扁平化的图像转换为二维表示,而另一种则相反。
另一方面,生成网络将潜变量向量作为输入,并返回784值向量,该向量对应于扁平化的28x28图像。 请记住,该网络的目的是学习如何创建手写数字的无法区别的图像,这就是为什么其输出本身就是新图像的原因。
该网络将具有三个隐藏层,每个隐藏层之后是Leaky-ReLU非线性。 输出层将具有TanH激活函数,该函数将结果值映射到(-1,1)范围内,该范围与我们预处理的MNIST图像所界定的范围相同。
我们还需要一些其他功能,以允许我们创建随机噪声。随机噪声将从此链接中提出的均值0和方差1的正态分布中采样。
def noise(size):
'''
Generates a 1-d vector of gaussian sampled random values
'''
n = Variable(torch.randn(size, 100))
return n
优化
结果
最初生成的图像是纯噪声:
但是后来他们改进了,
在获得不错的合成图像之前,
也可以可视化学习过程。 正如您在下图中所看到的,开始时鉴别器错误非常高,因为它不知道如何正确地将图像分类为真实还是伪造。 当鉴别器变得更好并且其误差在步骤5k减小到约0.5时,生成器误差增加,证明了鉴别器的性能优于生成器,并且可以正确地对假样本进行分类。 随着时间的流逝和训练的继续,生成器误差会降低,这意味着生成的图像越来越好。 随着生成器的改进,鉴别器的误差也会增加,因为合成图像每次都变得越来越逼真。
生成器随时间的错误
鉴别器随时间的错误
我已经介绍了生成对抗网络。 我们首先了解它们是哪种算法,以及为什么它们如此重要。 接下来,我们探索了符合GAN的部分以及它们如何协同工作。 最终,我们通过编程并使用GAN的完全有效的实现进行编程,从而将理论与实践联系起来,该实现学会了创建MNIST数据集的综合示例。
本文源码
详情参阅 - 亚图跨际