GAN
GAN简介
生成式对抗网络(Generative adversarial networks,GANs)的核心思想源自于零和博弈,包括生成器和判别器两个部分。生成器接收随机变量并生成“假”样本,判别器则用于判断输入的样本是真实的还是合成的。两者通过相互对抗来获得彼此性能的提升。判别器所作的其实就是一个二分类任务,我们可以计算他的损失并进行反向传播求出梯度,从而进行参数更新。
GAN的优化目标可以写作:
min G max D V ( D , G ) = E x ∼ p d a t a [ log D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \large {\min_G\max_DV(D,G)= \mathbb{E}_{x\sim p_{data}}[\log D(x)]+\mathbb{E}_{z\sim p_z(z)}[log(1-D(G(z)))]} GminDmaxV(D,G)=Ex∼pdata[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
其中 log D ( x ) \log D(x) logD(x)代表了判别器鉴别真实样本的能力,而 D ( G ( z ) ) D(G(z)) D(G(z))则代表了生成器欺骗判别器的能力。在实际的训练中,生成器和判别器采取交替训练,即先训练D,然后训练G,不断往复。
WGAN
在上一部分我们给出了GAN的优化目标,这个目标的本质是在最小化生成样本与真实样本之间的JS距离。但是在实验中发现,GAN的训练非常的不稳定,经常会陷入坍缩模式。这是因为,在高维空间中,并不是每个点都可以表示一个样本,而是存在着大量不代表真实信息的无用空间。当两个分布没有重叠时,JS距离不能准确的提供两个分布之间的差异。这样的生成器,很难“捕捉”到低维空间中的真实数据分布。因此,WGAN(Wasserstein GAN)的作者提出了Wasserstein距离(推土机距离)的概念,其公式可以进行如下表示:
W ( P r , P g ) = inf γ ∼ ∏ P r , P g E ( x , y ) γ [ ∥ x − y ∥ ] W(\mathbb P_r,\mathbb P_g)=\inf_{\gamma\sim\prod{\mathbb P_r,\mathbb P_g}}\mathbb E_{(x,y)~\gamma}[\|x-y\|] W(Pr,Pg)=γ∼∏Pr,PginfE(x,y) γ[∥x−y∥]
这里 ∏ P r , P g \prod{\mathbb P_r,\mathbb P_g} ∏Pr,Pg指的是真实分布 P r \mathbb P_r Pr和生成分布 P g \mathbb P_g Pg的联合分布所构成的集合, ( x , y ) (x,y) (x,y)是从 γ \gamma γ中取得的一个样本。枚举两者之间所有可能的联合分布,计算其中样本间的距离 ∥ x − y ∥ \|x-y\| ∥x−y∥,并取其期望。而Wasserstein距离就是两个分布样本距离期望的下界值。这个简单的改进,使得生成样本在任意位置下都能给生成器带来合适的梯度,从而对参数进行优化。
DCGAN
卷积神经网络近年来取得了耀眼的成绩,展现了其在图像处理领域独特的优势。很自然的会想到,如果将卷积神经网络引入GAN中,是否可以带来效果上的提升呢?DCGAN(Deep Convolutional GANs)在GAN的基础上优化了网络结构,用完全的卷积替代了全连接层,去掉池化层,并采用批标准化(Batch Normalization,BN)等技术,使得网络更容易训练。
用DCGAN生成图像
为了更方便准确的说明DCGAN的关键环节,这里用一个简化版的模型实例来说明。代码基于pytorch深度学习框架,数据集采用MNIST
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
import os
#定义一些超参数
nc = 1 #输入图像的通道数
nz = 100 #输入噪声的维度
num_epochs = 200 #迭代次数
batch_size = 64 #批量大小
sample_dir = 'gan_samples'
# 结果的保存目录
if not os.path.exists(sample_dir):
os.makedirs(sample_dir)
# 加载MNIST数据集
trans = transforms.Compose([
transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
mnist = torchvision.datasets.MNIST(root=r'G:\VsCode\ml\mnist',
train=True,
transform=trans,
download=False)
data_loader = torch.utils.data.DataLoader(dataset=mnist,
batch_size=batch_size,
shuffle=True)