模型简介
生成式对抗网络(Generative Adversarial Networks,GAN)是一种生成式机器学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。
最初,GAN由Ian J. Goodfellow于2014年发明,并在论文Generative Adversarial Nets中首次进行了描述,其主要由两个不同的模型共同组成——生成器(Generative Model)和判别器(Discriminative Model):
- 生成器的任务是生成看起来像训练图像的“假”图像;
- 判别器需要判断从生成器输出的图像是真实的训练图像还是虚假的图像。
GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。
GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型 G G G 和估计样本是否来自训练数据的判别模型 D D D 。
在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。
用 x x x 代表图像数据,用 D ( x ) D(x) D(x)表示判别器网络给出图像判定为真实图像的概率。在判别过程中, D ( x ) D(x) D(x) 需要处理作为二进制文件的大小为 1 × 28 × 28 1\times 28\times 28 1×28×28 的图像数据。当 x x x 来自训练数据时, D ( x ) D(x) D(x) 数值应该趋近于 1 1 1 ;而当 x x x 来自生成器时, D ( x ) D(x) D(x) 数值应该趋近于 0 0 0 。因此 D ( x ) D(x) D(x) 也可以被认为是传统的二分类器。
用 z z z 代表标准正态分布中提取出的隐码(隐向量),用 G ( z ) G(z) G(z):表示将隐码(隐向量) z z z 映射到数据空间的生成器函数。函数 G ( z ) G(z) G(z) 的目标是将服从高斯分布的随机噪声 z z z 通过生成网络变换为近似于真实分布 p d a t a ( x ) p_{data}(x) pdata(x) 的数据分布,我们希望找到 θ θ θ 使得 p G ( x ; θ ) p_{G}(x;\theta) pG(x;θ) 和 p d a t a ( x ) p_{data}(x) pdata(x) 尽可能的接近,其中 θ \theta θ 代表网络参数。
D ( G ( z ) ) D(G(z)) D(G(z)) 表示生成器 G G G 生成的假图像被判定为真实图像的概率,如Generative Adversarial Nets中所述, D D D 和 G G G 在进行一场博弈, D D D 想要最大程度的正确分类真图像与假图像,也就是参数 log D ( x ) \log D(x) logD(x);而 G G G 试图欺骗 D D D 来最小化假图像被识别到的概率,也就是参数 log ( 1 − D ( G ( z ) ) ) \log(1−D(G(z))) log(1−D(G(z)))。因此GAN的损失函数为:
min G max D V ( D , G ) = E x ∼ p d a t a ( x ) [ log D ( x ) ] + E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] \min\limits_{G}\max\limits_{D} V(D,G)=E_{x\sim p_{data}\;\,(x)}[\log D(x)]+E_{z\sim p_{z}\,(z)}[\log (1-D(G(z)))] GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
从理论上讲,此博弈游戏的平衡点是 p G ( x ; θ ) = p d a t a ( x ) p_{G}(x;\theta) = p_{data}(x) pG(x;θ)=pdata(x),此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:
- 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
- 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
- 生成器通过优化,生成出更加贴近真实数据分布的数据。
- 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2。
在上图中,蓝色虚线表示判别器,黑色虚线表示真实数据分布,绿色实线表示生成器生成的虚假数据分布, z z z 表示隐码, x x x 表示生成的虚假图像 G ( z ) G(z) G(z)。该图片来源于Generative Adversarial Nets。详细的训练方法介绍见原论文。
数据集
数据集简介
MNIST手写数字数据集是NIST数据集的子集,共有70000张手写数字图片,包含60000张训练样本和10000张测试样本,数字图片为二进制文件,图片大小为28*28,单通道。图片已经预先进行了尺寸归一化和中心化处理。
本案例将使用MNIST手写数字数据集来训练一个生成式对抗网络,使用该网络模拟生成手写数字图片。
数据集下载
使用download
接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用pip install download
安装download
包。
下载解压后的数据集目录结构如下:
./MNIST_Data/
├─ train
│ ├─ train-images-idx3-ubyte
│ └─ train-labels-idx1-ubyte
└─ test
├─ t10k-images-idx3-ubyte
└─ t10k-labels-idx1-ubyte