收敛速度更快更稳定的Wasserstein GAN(WGAN)

生成对抗网络(GANs)是一种很有力的生成模型,它解决生成建模问题的方式就像在两个对抗式网络中进行比赛:给出一些噪声源,生成器网络能够产生合成的数据,鉴别器网络在真实数据和生成器的输出中进行鉴别。GAN能够产生十分生动的样例,但是很难训练。尽管最近大量的研究工作都投入到了寻找能让GAN稳定训练的方法上,但GAN的持续稳定训练成了依然是一个公开的问题。
概要最近提出的Wasserstein GAN(WGAN)在GAN的稳定训练上取得了重大进展,但是依然会产生低质量的样例或者出现在某些设置上不能收敛的情况。会产生这些训练失败的样例一般都是因为,作者通过在WGAN中使用权重修剪来达到在鉴别器中强制利普希茨(Lipschitz)限制条件的目的,但这样的方式会导致病态行为。
根据上述情况,文章提出了一种替代方法来强制Lipschitz限制条件:不修剪权重,而是根据输入来惩罚鉴别器的梯度正则项。这一方法方法与权重修剪WGAN相比,收敛的更快并且能产生更高质量的样例。这一替代方法基本能够保证很稳定的GAN训练。这是自深度学习发展以来,第一次可以训练多种多样的GAN结构,并且几乎不用进行超参数的调整,包括101层的ResNet和离散数据上的语言模型。
本文将从提出背景,算法介绍,实验结果,论文实现等四个方面来介绍这篇文章。
提出背景作者发现,WGAN的权重修剪会导致优化困难,并且即使能够优化成功,鉴别器也会出现病态等值面。作者测试了了WGAN中的权重限制条件(对每个权重的量级进行硬修剪),也测试了其他的权重限制条件(L2范数修剪,权重正则化等等),以及一些软限制条件(L1和L2权重衰减等等),实验结果发现他们都会出现类似的问题。
容量使用不当
作者在玩具数据库上使用权重修剪来训练WGAN鉴别器使其达到最优,保持生成器的分布固定,以及单元变量高斯噪声,对鉴别器的等值面曲线图进行了绘制。我们省略了鉴别器中的批正则化(batch normalization)。在每个样例中,作者发现,用权重修剪训练的鉴别器忽略了数据分布的高阶矩,而是对很简单的近似建模来进行优化函数。相比而言,梯度惩罚的方法不会因为这样的行为受到影响。


梯度消失和爆炸作者发现,WGAN的优化过程之所以很难,是由于权重修剪和损失函数之间的互动。这一情况会不可避免的导致梯度消失或梯度爆炸,取决于修剪的阈值。如果权重由于限制变得太小,梯度就会在反向传播到之前的层时消失。这会阻止鉴别器(和生成器)之前的层接受有用的训练信号,并且会使得深度网络学习速度变得很慢。


(a)深度WGAN的鉴别器的梯度正则项在玩具数据库上训练时的变化情况。用权重修剪的WGAN中的梯度总是爆炸或消失,而作者提出的梯度惩罚方法则为之前层提供了稳定梯度。
(b)分别使用权重修剪(左)和使用梯度惩罚(右)的WGAN的权重直方图。权重修剪将权重推到了修剪范围的极限,当这个范围很高时,就会导致梯度爆炸,然后减慢训练速度。
算法——梯度惩罚由于WGAN中的权重修剪带来的不好的结果,作者提出了一种替代方法,在训练目标上加强Lipschitz限制条件:当且仅当一个可微函数梯度的正则项处处小于等于1时,它才满足1-Lipschitz条件。所以作者直接根据输入来限制鉴别器的梯度正则项,也就是梯度惩罚。
分以下几个步骤实现:
·根据直线采样。
·超参数:梯度惩罚引入了一个参数λ,实验中设为10。
·去掉鉴别器的批正则化。
·使用Adam参数设置。
·双面惩罚。
·二次惩罚。
实验结果与权重修剪WGAN相比,梯度惩罚WGAN不仅提高了收敛速度,还提升了网络训练的稳定性。
CIFAR-10 训练速度和样例质量


四个模型在CIFAR-10数据集上的测试分数与生成器迭代次数(左)和系统时间(右)的曲线图。四个模型分别为:权重修剪WGAN,梯度惩罚以及RMSProp(控制优化器)WGAN,梯度惩罚以及Adam参数设置WGAN和DCGAN。从图中可以看出,即使在同样的学习速率下,梯度惩罚方法的表现比权重修剪有显著的提高。DCGAN收敛的更快,但是使用梯度惩罚的WGAN达到相似的分数时,稳定性提高了。
LSUN卧室数据集作者在LSUN卧室数据集上训练了多种GAN模型。除了作为基准的DCGAN,还选择了六个比较难训练的结构:
生成器没有批正则化以及的连续数量滤波
1.4层的512维ReLU MLP生成器
2.生成器和鉴别器中都没有正则化
3.门控相乘非线性
4.Tanh非线性
5.101层ResNet 生成器和鉴别器


从图中可以看出,使用不同方法训练的GAN结构,只有作者提出的使用梯度惩罚WGAN方法在每个结构的训练中都成功了。
字符级语言建模


上部:在Billion Word数据集上测试的WGAN字符级语言模型样例,缩短到32个字符。
底部:使用标准GAN训练的同样结构的模型样例。
据我们所知,这是第一个完全用对抗式训练的语言生成模型,而不需要相似性最大化损失的监督。从图中可以看出它会产生很多拼写错误,但是依然能成功的学习到自然语言统计学的很多数据。
损失曲线和过拟合检测


(a)作者的模型在LSUN卧室数据集上的鉴别器损失,随着网络的训练收敛到最小值。
(b)在1000位MNIST子数据集上的WGAN训练和验证损失。可以看出使用我们的方法(左)或权重修剪法(右)都会产生过拟合。我们的方法中,鉴别器比生成器过拟合速度更快,是的训练损失随时间逐渐增加,而验证损失随之减少。
论文地址及实现
论文链接:https://arxiv.org/abs/1704.00028
github链接:https://github.com/igul222/improved_wgan_training
环境要求:
·Python, NumPy, TensorFlow,SciPy, Matplotlib
·NVIDIA GPU
模型
所有模型的配置在文件最上面的一列常数中已经进行了特别说明。有两个模型应该可以直接使用:
·python gan_toy.py: 玩具数据库(8 Gaussians, 25 Gaussians, SwissRoll).
·python gan_mnist.py: MNIST手写数字识别库
对于其他的模型,在运行之前请修改脚本,指定数据集的路径为DATA_DIR下。每个模型的数据集现在都已公开;下载链接已包含在文件中。
·python gan_64x64.py: 64x64 结构 (论文中,该模型是在ImageNet上进行训练的,而不是LSUN卧室图片库)
·python gan_language.py: 字符级别的语言模型
·python gan_cifar.py: CIFAR-10数据库
AIJob社是《全球人工智能》旗下专门为AI开发工程师免费服务的求职平台。我们将竭尽全力帮助每一个ai工程师对接自己喜欢的企业,推荐给你喜欢的直接领导,帮你谈一个最好的薪资待遇。
微信咨询:aihr007简历投递:hr@top25.cn企业合作:job@top25.cn
《全球人工智能》招聘5名兼职翻译及10名兼职VIP社群专家:图像技术、语音技术、自然语言、机器学习、数据挖掘等专业技术领域,工作内容及待遇请在公众号内回复“兼职+个人微信号”联系工作人员。
热门文章推荐
重磅|全球AI报告:看看谷歌|苹果|Facebook等十几家巨头都收购了哪些牛逼的AI公司?
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值