生成式对抗网络(GANs)综述

本文深入探讨了生成式对抗网络(GANs),包括其基本原理、Wasserstein GAN (WGAN)、Deep Convolutional GANs (DCGAN)、条件生成式对抗网络(cGAN)、Pix2Pix、CycleGAN以及StarGAN。这些模型在图像生成、图像到图像翻译等领域展现出强大的能力。通过对各种GAN变种的分析,展示了它们在解决不同问题时的创新之处。
摘要由CSDN通过智能技术生成

GAN

GAN简介

生成式对抗网络(Generative adversarial networks,GANs)的核心思想源自于零和博弈,包括生成器和判别器两个部分。生成器接收随机变量并生成“假”样本,判别器则用于判断输入的样本是真实的还是合成的。两者通过相互对抗来获得彼此性能的提升。判别器所作的其实就是一个二分类任务,我们可以计算他的损失并进行反向传播求出梯度,从而进行参数更新。

GAN的优化目标可以写作:
min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \large {\min_G\max_DV(D,G)= \mathbb{E}_{x\sim p_{data}}[\log D(x)]+\mathbb{E}_{z\sim p_z(z)}[log(1-D(G(z)))]} GminDmaxV(D,G)=Expdata[logD(x)]+Ezpz(z)[log(1D(G(z)))]
其中 log ⁡ D ( x ) \log D(x) logD(x)代表了判别器鉴别真实样本的能力,而 D ( G ( z ) ) D(G(z)) D(G(z))则代表了生成器欺骗判别器的能力。在实际的训练中,生成器和判别器采取交替训练,即先训练D,然后训练G,不断往复。

WGAN

在上一部分我们给出了GAN的优化目标,这个目标的本质是在最小化生成样本与真实样本之间的JS距离。但是在实验中发现,GAN的训练非常的不稳定,经常会陷入坍缩模式。这是因为,在高维空间中,并不是每个点都可以表示一个样本,而是存在着大量不代表真实信息的无用空间。当两个分布没有重叠时,JS距离不能准确的提供两个分布之间的差异。这样的生成器,很难“捕捉”到低维空间中的真实数据分布。因此,WGAN(Wasserstein GAN)的作者提出了Wasserstein距离(推土机距离)的概念,其公式可以进行如下表示:
W ( P r , P g ) = inf ⁡ γ ∼ ∏ P r , P g E ( x , y )   γ [ ∥ x − y ∥ ] W(\mathbb P_r,\mathbb P_g)=\inf_{\gamma\sim\prod{\mathbb P_r,\mathbb P_g}}\mathbb E_{(x,y)~\gamma}[\|x-y\|] W(Pr,Pg)=γPr,PginfE(x,y) γ[xy]
这里 ∏ P r , P g \prod{\mathbb P_r,\mathbb P_g} Pr,Pg指的是真实分布 P r \mathbb P_r Pr和生成分布 P g \mathbb P_g Pg的联合分布所构成的集合, ( x , y ) (x,y) (x,y)是从 γ \gamma γ中取得的一个样本。枚举两者之间所有可能的联合分布,计算其中样本间的距离 ∥ x − y ∥ \|x-y\| xy,并取其期望。而Wasserstein距离就是两个分布样本距离期望的下界值。这个简单的改进,使得生成样本在任意位置下都能给生成器带来合适的梯度,从而对参数进行优化。

DCGAN

卷积神经网络近年来取得了耀眼的成绩,展现了其在图像处理领域独特的优势。很自然的会想到,如果将卷积神经网络引入GAN中,是否可以带来效果上的提升呢?DCGAN(Deep Convolutional GANs)在GAN的基础上优化了网络结构,用完全的卷积替代了全连接层,去掉池化层,并采用批标准化(Batch Normalization,BN)等技术,使得网络更容易训练。

用DCGAN生成图像

为了更方便准确的说明DCGAN的关键环节,这里用一个简化版的模型实例来说明。代码基于pytorch深度学习框架,数据集采用MNIST

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
import os 
#定义一些超参数
nc = 1    				#输入图像的通道数
nz = 100   				#输入噪声的维度
num_epochs = 200		#迭代次数
batch_size = 64			#批量大小
sample_dir = 'gan_samples'
# 结果的保存目录
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
# 加载MNIST数据集
trans = transforms.Compose([
                transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
mnist = torchvision.datasets.MNIST(root=r'G:\VsCode\ml\mnist',
                                   train=True,
                                   transform=trans,
                                   download=False)
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size, 
                                          shuffle=True)
判别器&生成器
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值