1. GAN的缺点
上一篇讲了GAN,其实代码中展示的那种生成器G的表达式,有一个很大的缺点,就是梯度消失严重。其实除了上一篇文章写的那种表达式,还有另一种方法,但是第二种方法也面临梯度不稳定和模式崩塌的问题。具体原因建议移步哔哩哔哩找李宏毅老师。
基础GAN的生成器和判别器损失迭代10000次数据如下面几张图所示。判别器一开始分数很高接近1,因为他能很轻易分辨,生成器一开始分数很低,因为很难生成符合条件的分布。
但是随着迭代次数不断增加,二者都趋近0.5,难舍难分,说明达到了最优状态。判别器已经分不出哪些是真实数据哪些是生成数据了,下面三张图都大概在接近10000次才能达到最优。
2. WGAN的改进
上述原因和kl散度,js散度,sigmod函数有关。未来更好的度量两个分布之间的距离,WGAN讲原始的判别器改为了计算两个分布之间的推土机距离,具体算法如下:
其实改进只有四点:
- 判别器最后一层去掉sigmod
D = nn.Sequential( #定义判别器
nn.Linear(2,64),
nn.ReLU(),
nn.Linear(64,1),
#nn.Sigmoid()
)
- 生成器和判别器的loss不取log
G_loss = -torch.mean(pro_atrist1)
D_loss = -torch.mean(pro_atrist0-pro_atrist1)
- 每次更新D的参数前进行梯度裁剪截断到固定常数c
for p in D.parameters():
p.data.clamp_(-0.01, 0.01)
- 不用动量的优化算法,推荐RMSprop或者SGD
optimizer_G = torch.optim.RMSprop(G.parameters(),lr=0.0001) #定义生成器优化函数
optimizer_D = torch.optim.RMSprop(D.parameters(),lr=0.0001) #定义判别器优化函数
3. 实验效果
下面是WGAN的效果,可以看出在4000次左右判别器就无法进行区分了,对比原始的GAN有很大的提升。