WGAN简介

WGAN介绍

1. 论文基本信息

(1) 作者:Martin Arjovsky, Soumith Chintala, Leon Bottou
(2) 题目:Wasserstein GAN
(3) 出处:International Conference on Machine Learning(ICML)
(4) 年份:2017

2. 论文背景及针对的问题

生成对抗网络(GAN)是一种无监督的深度学习模型,通过框架中的两个模块:生成模型和判别模型的相互博弈学习来得到较好的输出。但是原始的GAN网络采用的是迭代的方式进行训练,训练过程很困难,通常需要设计启发式算法以及精心设计网络架构,不具有通用性,且生成器与判别器的损失函数无法指示训练过程,同时还存在着生成样本缺乏多样性的问题。该论文主要针对原始GAN网络的这些问题对GAN网络进行了改进,提出了Wasserstein GAN(WGAN)。

3. 论文主要创新点及贡献

针对上述问题,论文中引入了Wasserstein距离来替代JS 散度和KL散度来作为优化目标,基于Wasserstein距离相对于KL散度与JS散度具有优越的平滑特性,从根本上解决了原始GAN的梯度消失问题。同时改进了训练流程,解决了GAN训练不稳定以及模式崩溃的问题,同时提供了可以指示训练进程的数值,网络不再需要精心设计。

4. 论文主要原理及算法介绍

(1) GAN网络基础介绍
GAN包含生成器G和判别器D两个部分,其结构如图所示
在这里插入图片描述
GAN的训练目标为
在这里插入图片描述

(2)WGAN的改进
论文作者分析了GAN训练不稳定问题的来源:原始GAN训练目标的形式等价于在最优判别器下最小化生成分布与真实分布的KL散度及JS散度,但如果两个分布不存在不可忽略的重叠,而这种情况时常出现,其梯度变为常数,导致无法训练。作者引入了Wasserstein距离来替代原始的JS和KL散度,其定义为:
在这里插入图片描述

其中 是两个分布的联合分布,即对每个可能的联合分布,从中采样真实样本x和生成样本y,计算该分布下距离的期望,所有期望的下界即为Wasserstein距离。作者推导得到Wasserstein距离更加平滑,提供了更有意义的梯度。
但实际中Wasserstein距离定义中下界无法求解,作者引入了Lipschitz连续限制,将其形式变换为
在这里插入图片描述

近似模拟Wasserstein距离,其中K即为对应的Lipschitz常数。因此可以构造一个参数为w,最后一层不是非线性激活层的判别器网络f_w,在限制参数w不超过某一范围下,最大化
在这里插入图片描述

与原始GAN的目标函数相比,只是将判别器最后一层sigmoid去掉,不取log,该损失函数可以指示训练的进程,其数值越小,说明真实分布与生成分布的Wasserstein距离越小,训练越好。同时作者在训练过程中还采取了将判别器参数更新后的绝对值截断到一定范围内以及不采用基于动量的优化算法等来优化训练过程,解决了训练不稳定问题。作者在实验部分进行了不少实验验证,证实了WGAN的有效性。

5. 不足

WGAN中采用权重截断的方式来满足Lipschitz连续性条件,但实际操作中采用这种方式可能会导致大部分权重最终都落在上下界限附近,因此神经网络的权重大部分可能只有两个数,极大限制了其拟合能力,导致梯度消失或爆炸。因此后续Ishaan等针对这一问题提出了WGAN-GP,采取了梯度惩罚解决了这一问题。

### WGAN-GP 的架构图与模型结构 WGAN-GP(Wasserstein GAN with Gradient Penalty)是一种改进版的生成对抗网络,在保持 Wasserstein 距离作为损失函数的同时引入梯度惩罚机制来稳定训练过程并提高生成效果[^1]。 #### 判别器与生成器设计 在 WGAN-GP 中,生成器 \(G\) 和判别器 \(D\) 都采用了监督学习的方式进行训练。两者具有对立的目标:生成器试图创建尽可能真实的样本欺骗判别器;而判别器则努力区分真实数据和由生成器产生的伪造品[^2]。 #### 关键组件说明 - **生成器 (Generator)**: 接收随机噪声向量 z 作为输入,并输出合成图像。 - **判别器 (Discriminator)**: 输入可以是来自实际分布的真实图片或是通过生成器得到的人造实例,输出是对该输入属于真伪的概率估计值。 为了确保 Lipschitz 连续性条件成立,从而使得理论上的最优解可达到,WGAN-GP 对原始 WGAN 增加了一个额外项——即所谓的 "gradient penalty" (梯度惩罚),这有助于防止权重裁剪带来的数值不稳定问题,并进一步增强了模型的表现力[^3]。 ```mermaid graph LR; A[Noise Vector z] --> B(Generator); C[Real Image Data] & D[Fake Images from Generator] --> E(Discriminator); F[Wasserstein Loss + Gradient Penalty] -.->|Optimize|E; ``` 此图表展示了基本框架下的交互流程,其中包含了两个主要部分: - 左侧表示生成路径,从随机噪音到假象; - 右侧展示评估环节,涉及真假两类源的数据流经判别模块后的处理逻辑。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值