wgan 不理解 损失函数_【GAN-8】WGAN-Gradient Penalty

论文

Improved Training of Wasserstein GANs​xxx.itp.ac.cn

我们之前说了,WGAN的(启发式的)保证函数

的方法是让
的参数
满足

这一看就是很扯淡的方法,这篇文章则是对这个的改进。

先说说有什么问题。在GAN-GP这篇论文中,作者给出了WGAN的两个主要缺点,同时用了一个toy example说明这些问题。

作者发现不仅是原文中的直接对

clip,同时,对
的L2 norm clip,soft的约束
的L1,L2 norm,等等,都有这些问题。

总之一句话,直接对

下手就是不行。

Capacity underuse

这是容易理解的,毕竟你把

约束在了一个很小的范围内,模型的容量自然很难得到保证。

作者们的toy example的大致思想是,把

都定下来,其中就是
的基础之上加一些噪声。

分别是8个Gaussians,25个Gaussians和Swiss Roll数据集,总之就是三个确定的分布。

214851ee964287fc6064326c925a1dc9.png

上图中的第一排是WGAN中critic(其实就是discriminator,他们换了个名字)的值的图像,下图的则是WGAN-GP的,很容易看出WGAN的模型复杂度确实有影响,WGAN-PG要看起来好得多。

Exploding and vanishing gradients

这同样是直接对

约束带来的后果,作者尝试了WGAN不同的clip画出来的梯度的norm。

注意,随着层数越靠近输入层,norm的波动应该越大,毕竟梯度是反着来的。

39779bb3f9048d9e4c346d270a8fc6ea.png

上图说明了WGAN的梯度不是爆炸就是消失。

当然在GAN中一般都使用了batch normalization的技术,梯度的波动不会这么剧烈,但是WGAN的性能可能会受到影响。

当然,原始的WGAN还有一个缺点,就是实际上根本不能保证clip的函数

是1-Lipschitz的
,那WGAN的W就无从谈起了。

作者的意思是,既然我们想让

满足1-Lipschitz,而1-Lipschitz可以看作
梯度处处小于1,那么我们为何不直接加这个约束呢?

于是他们提出了Gradient penalty,这就是算法名字中GP的由来。

于是,现在的损失函数形如

显然那个Our gradient penalty比较有说道,下面的是对这个公式的具体说明。

Sampling distribution

是什么呢?

我们当然希望

是整个空间的均匀分布,这样能保证
处处都是1-Lipschitz的,但是这实际上是不现实的。而我们实际上也只要保证
之间的点
满足这个性质就可以了。

因此我们从

采样一个点,再从
采样一个点,这样形成了一条
线段,然后从这个线段上采样。

Penalty coefficient

这是个超参数,经验上取

就好了。

No critic batch normalization

这个也显然,用了batch normalization还咋Gradient penalty嘛。

Two-sided penalty

这个比较重要,既然我们希望梯度处处小于1,为什么不做单边约束,也就是

这是因为实际上,EM距离表示为

而函数

的要求虽然是对任意

但是往往是取等号的,毕竟要求极值

嘛。

因此这里就启发式的写成Two-sided penalty而不是One-sided penalty了。

当然这只是经验上的改进。


算法长这样

e73b133351ae622e560cbfb4889256c9.png

和WGAN的一个小区别是WGAN-GP用了adam做优化,而WGAN用的是RMSprop,不过这是细节啦。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
DCGAN(Deep Convolutional Generative Adversarial Network)和WGAN-GP(Wasserstein GAN with Gradient Penalty)是两种常用的生成对抗网络模型。它们的损失函数公式如下: DCGAN损失函数公式: DCGAN使用了两个网络:生成器(Generator)和判别器(Discriminator)。生成器的目标是生成逼真的样本,而判别器的目标是区分真实样本和生成样本。DCGAN损失函数包括两部分:生成器损失和判别器损失。 1. 生成器损失: 生成器损失使用了交叉熵损失函数,表示生成样本被判别为真实样本的概率的负对数: L_G = -log(D(G(z))) 其中,G(z)表示生成器生成的样本,D表示判别器,z表示生成器的输入噪声。 2. 判别器损失: 判别器损失也使用了交叉熵损失函数,表示真实样本被判别为真实样本的概率和生成样本被判别为生成样本的概率的负对数之和: L_D = -log(D(x)) - log(1 - D(G(z))) 其中,x表示真实样本。 WGAN-GP损失函数公式: WGAN-GP是对Wasserstein GAN进行改进的模型,引入了梯度惩罚(Gradient Penalty)来解决原始WGAN的训练不稳定问题。WGAN-GP损失函数包括三部分:生成器损失、判别器损失和梯度惩罚项。 1. 生成器损失: 生成器损失与DCGAN相同,使用了交叉熵损失函数: L_G = -log(D(G(z))) 2. 判别器损失: 判别器损失也与DCGAN相同,使用了交叉熵损失函数: L_D = -log(D(x)) - log(1 - D(G(z))) 3. 梯度惩罚项: 梯度惩罚项是WGAN-GP的关键改进,用于约束判别器的梯度。它通过计算真实样本和生成样本之间的差异,并对差异进行惩罚。梯度惩罚项的计算公式如下: L_GP = λ * ∥∇D(εx + (1-ε)G(z))∥₂ - 1∥² 其中,ε是从[0, 1]均匀采样的随机数,λ是梯度惩罚系数。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值