史上第二简单的WGAN的pytorch实现

1. GAN的缺点

上一篇讲了GAN,其实代码中展示的那种生成器G的表达式,有一个很大的缺点,就是梯度消失严重。其实除了上一篇文章写的那种表达式,还有另一种方法,但是第二种方法也面临梯度不稳定和模式崩塌的问题。具体原因建议移步哔哩哔哩找李宏毅老师。
基础GAN的生成器和判别器损失迭代10000次数据如下面几张图所示。判别器一开始分数很高接近1,因为他能很轻易分辨,生成器一开始分数很低,因为很难生成符合条件的分布。
但是随着迭代次数不断增加,二者都趋近0.5,难舍难分,说明达到了最优状态。判别器已经分不出哪些是真实数据哪些是生成数据了,下面三张图都大概在接近10000次才能达到最优。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

2. WGAN的改进

上述原因和kl散度,js散度,sigmod函数有关。未来更好的度量两个分布之间的距离,WGAN讲原始的判别器改为了计算两个分布之间的推土机距离,具体算法如下:
在这里插入图片描述
其实改进只有四点:

  1. 判别器最后一层去掉sigmod
D = nn.Sequential( #定义判别器
    nn.Linear(2,64),
    nn.ReLU(),
    nn.Linear(64,1),
    #nn.Sigmoid()
)
  1. 生成器和判别器的loss不取log
G_loss = -torch.mean(pro_atrist1)
D_loss = -torch.mean(pro_atrist0-pro_atrist1) 
  1. 每次更新D的参数前进行梯度裁剪截断到固定常数c
for p in D.parameters():
     p.data.clamp_(-0.01, 0.01)
  1. 不用动量的优化算法,推荐RMSprop或者SGD
optimizer_G = torch.optim.RMSprop(G.parameters(),lr=0.0001) #定义生成器优化函数
optimizer_D = torch.optim.RMSprop(D.parameters(),lr=0.0001) #定义判别器优化函数

3. 实验效果

下面是WGAN的效果,可以看出在4000次左右判别器就无法进行区分了,对比原始的GAN有很大的提升。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

平什么阿

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值