WGAN:优化不完美的GAN

引言

​   前不久做了初代GAN的实验,感受到了生成模型的强大,最近在看GAN的的变体WGAN感觉到了数学的强大,仅仅在原始的GAN上稍作修改就能达到不一样的效果,真实的感受到了数学是的魅力,本次也是接着李宏毅老师的课件对WGAN的进行了一些整理。当然,非常推荐大家先看篇博文, 然后,读一读作者的paper,多读几遍,总会有醍醐灌顶的效果。不过,首先来看一下WGAN解决了那些问题,以便于我们很好理解其中的点。

we empirically show that WGANs cure the main training problems of GANs. In particular, training WGANs does not require maintaining a careful balance in training of the discriminator and the generator, and does not require a careful design of the network architecture either. The mode dropping phenomenon that is typical in GANs is also drastically reduced. One of the most compelling practical benefits of WGANs is the ability to continuously estimate the EM distance by training the discriminator to optimality. Plotting these learning curves is not only useful for debugging and hyper- parameter searches, but also correlate remarkably well with the observed sample quality.

原始GAN的问题

真实数据分布和生成数据分布很少重叠

​    原始GAN使用JS divergence来衡量 P G P_G PG P d a t a P_{data} Pdata的有多像,但是在很多情况下 P G P_G PG P d a t a P_{data} Pdata是不重叠的,如 P G P_G PG P d a t a P_{data} Pdata和是高维空间的低纬流体,这使得它们之间即使存在有相互重叠的部分,也在很大程度上可以忽略。例如:二维空间中两条相交的线段,它们在重叠在相交点的那一部分对于整体来说是微不足道的。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-281sV0KB-1573387985078)(Imgs/JS不太好.png)]

JS divergence本身的问题

​   如果 P G P_G PG P d a t a P_{data} Pdata的不重叠,那么我们的JS衡量相当于一个常数 l o g 2 log2 log2,而让我们直观上感觉,下图中间部分会比较好一点比左面部分,因为它使得两个数据分布更加的接近了,这也是我们想要的结果,但是我们使用JS来衡量数据分布直接接近程度,却在在它们不重叠的时候一直是一个常数,这使我们无法通过JS来判断我们 P G P_G PG P d a t a P_{data} Pdata的有多接近,我们失去了衡量彼此接近程度的标准。(为什么等于log2)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RZjw23Gm-1573387985080)(JS problem.png)]

新的衡量距离

EM距离的原理

​   所以,WGAN直接提出一种新的衡量距离的方法:Earth Mover’s Distance(推土机距离),我们把数据分布P作为一方土,把另一个数据分布Q作为目标。我们把P通过推土机移动到目标Q上。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XkeugO9U-1573387985080)(/home/gavin/Documents/WGAN/Imgs/EM.png)]

​   那么,我们把数据分布P移动成数据分布Q存在很多种方案,我们可以使最大移动量或者最小的移动量,在这里我们限定使用最小移动量来作为衡量标准,即我们希望最小的移动量达到最优的拟合结果。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-onheklow-1573387985081)(Imgs/EM.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EOigrCSY-1573387985082)(Imgs/Best move.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-7xsxRPhP-1573387985083)(Imgs/矩阵移动.png)]

为什么EM距离好?

​   我们首先来有以下两个概率密度函数,我们通过缩小 θ \theta θ,使得 θ \theta θ作为我们的距离衡量标准。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TijD7ue4-1573387985083)(Imgs/移动距离.png)]

​   下面是几种距离衡量的标准,我们可以直观的观察得到,当 θ \theta θ在不断的缩减过程中,EM距离是一个连续的值,尽管数据分布不重叠,但是,在JS和KL等距离下无法得到合适的测度,因为它们不连续,并且是一个跳跃的值。
W ( P 0 , P θ ) = ∣ θ ∣ J ( P 0 , P θ ) = { log ⁡ 2  if  θ ≠ 0 0  if  θ = 0 K L ( P θ ∥ P 0 ) = K L ( P 0 ∥ P θ ) = { + ∞  if  θ ≠ 0 0  if  θ = 0  and  δ ( P 0 , P θ ) = { 1  if  θ ≠ 0 0  if  θ = 0 \begin{array}{l}{W\left(\mathbb{P}_{0}, \mathbb{P}_{\theta}\right)=|\theta|} \\ {J\left(\mathbb{P}_{0}, \mathbb{P}_{\theta}\right)=\left\{\begin{array}{ll}{\log 2} & {\text { if } \theta \neq 0} \\ {0} & {\text { if } \theta=0}\end{array}\right.} \\ {K L\left(\mathbb{P}_{\theta} \| \mathbb{P}_{0}\right)=K L\left(\mathbb{P}_{0} \| \mathbb{P}_{\theta}\right)=\left\{\begin{array}{ll}{+\infty} & {\text { if } \theta \neq 0} \\ {0} & {\text { if } \theta=0}\end{array}\right.} \\ {\text { and } \delta\left(\mathbb{P}_{0}, \mathbb{P}_{\theta}\right)=\left\{\begin{array}{ll}{1} & {\text { if } \theta \neq 0} \\ {0} & {\text { if } \theta=0}\end{array}\right.}\end{array} W(P0,Pθ)=θJ(P0,Pθ)={log20 if θ=0 if θ=0KL(PθP0)=KL(P0Pθ)={+0 if θ=0 if θ=0 and δ(P0,Pθ)={10 if θ=0 if θ=0
​ 下面我直接使用原paper中的一部分,更加直观来比较JS和EM距离。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-prhxVsZH-1573387985084)(Imgs/JS and EM.png)]
Figure 1: These plots show ρ(Pθ, P0) as a function of θ when ρ is the EM distance (left plot) or the JS divergence (right plot). The EM plot is continuous and provides a usable gradient everywhere. The JS plot is not continuous and does not provide a usable gradient.

WGAN

从EM过渡到WGAN

​   所以,我们基于EM距离提出了WGAN,我们提出了有约束的判别器(满足1-Lipschitz),而Lipschitz连续条件限制了一个连续函数的最大局部变动幅度。然后,最大化 V ( G , D ) V(G, D) V(G,D),而我们为了满足约束条件采取了一种非常暴力的方法"weight clipping", 原论文中也说了这是一种非常槽糕的方式去使得判别器满足这个约束。权重裁剪的方式也很简单,只需要在反向传播把更新的权重强制夹到一个范围内就可以。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1Y5sCWH9-1573387985085)(Imgs/WGAN.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5dJfRqzO-1573387985086)(Imgs/weight clipping.png)]

​   当然这样的权重裁剪,适当的值也是非常重要的,所以原论文给出了权重裁剪过大或者过小时出现的问题,如下:

Weight clipping is a clearly terrible way to enforce a Lipschitz constraint. If the clipping parameter is large, then it can take a long time for any weights to reach their limit, thereby making it harder to train the critic till optimality. If the clipping is small, this can easily lead to vanishing gradients when the number of layers is big, or batch normalization is not used (such as in RNNs).

WGAN算法

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-k39G1c28-1573387985086)(./Imgs/算法原理.png)]

​ 所以,可以看出对比原始的GAN,WGAN只改了以下四个部分:

  1. 判别器最后一层去掉sigmoid

  2. 生成器和判别器的loss不取log

  3. 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c

  4. 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行

对于上述第四点原paper中也做了相关的阐述,如下:

​ Finally, as a negative result, we report that WGAN training becomes unstable attimes when one uses a momentum based optimizer such as Adam [8] (with β1 > 0)on the critic, or when one uses high learning rates. Since the loss for the critic isnonstationary, momentum based methods seemed to perform worse. We identifiedmomentum as a potential cause because, as the loss blew up and samples got worse,the cosine between the Adam step and the gradient usually turned negative. Theonly places where this cosine was negative was in these situations of instability. Wetherefore switched to RMSProp [21] which is known to perform well even on verynonstationary problems [13]

而在真实的梯度更新的过程中,我们也能从下图中看到不同生成器在最优的判别器下的梯度更新情况

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jtNPhSoX-1573387985087)(Imgs/梯度更新.png)]
Figure 2: Optimal discriminator and critic when learning to differentiate two Gaussians.As we can see, the discriminator of a minimax GAN saturates and results in vanishing gradients. Our WGAN critic provides very clean gradients on all parts of the space.

Pytorch 复现

代码改动

​   本次复现只在上一个版本上进行了局部的改动,这也如前面所说只需改动原始GAN算法的四个位置即可,改动结果如下:

  1. 判别器最后一层去掉sigmoid

    class NetD(nn.Module):
        """
        构建一个判别器,相当与一个二分类问题, 生成一个值
        """
    
        def __init__(self, opt):
            super(NetD, self).__init__()
    
            ndf = opt.ndf
            self.main = nn.Sequential(
                # 输入96*96*3
                nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
    
                # 输入32*32*ndf
                nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 2),
                nn.LeakyReLU(0.2, True),
    
                # 输入16*16*ndf*2
                nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, True),
    
                # 输入为8*8*ndf*4
                nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, True),
    
                # 输入为4*4*ndf*8
                nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=True),
    
                # 去除最后一层的sigmoid
                # nn.Sigmoid()
    
            )
    
        def forward(self, x):
            return self.main(x)
    
  2. 生成器和判别器的loss不取log

    生成器loss

    G_loss = -1 * (t.mean(netd(gen_img)))
    

    判别器loss

    D_loss = -1 * t.mean(netd(real_img)) + t.mean(netd(fake_img))
    
  3. 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c

     for p in netd.parameters():
                        p.data.clamp_(-opt.clip_value, opt.clip_value)  # opt.clip_value = 0.01
    
  4. 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行

    optimizer_g = t.optim.SGD(netg.parameters(), lr=opt.lr1)
    optimizer_d = t.optim.SGD(netd.parameters(), lr=opt.lr2)
    
数据分析
  • 权重分布

    ​   实验过程中对判别器的权重进行收集,如下图,其中weight 1表示判别器第一层的卷积权重分布,以此类推,我在这里取了四个层的权重来进行对比,可以看出经过权重裁剪之后的权重分布偏向两侧。总感觉这样太暴力了,不过这种暴力裁剪的方法已经在WGAN-GP中已经得到解决了,最近在跟进。

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fQC1z7tV-1573387985088)(Imgs/weiht1.png)][外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tm6PVy6D-1573387985089)(Imgs/weight5.png)]
    weight 1weight 5
    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Cxm2LcMq-1573387985090)(Imgs/weight9.png)][外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2npHdyhl-1573387985090)(Imgs/weight11.png)]
    weight9weight 11
  • Loss

    D_loss:

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-JvnLnTSD-1573387985091)(Imgs/D_loss.png)]

    G_loss:

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-GFFaOmeY-1573387985091)(Imgs/G_loss.png)]

实验效果

​   下图是经过8000个epoch的效果,效果不是太好,可能是训练次数太少的效果,毕竟8000个epoch对于炼丹来说,还是差点意思,但是,我们主要是学习思想和方法,当然如果能做到好的实验效果也不能偷懒,哈哈!

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-noJB6nQB-1573387985092)(Imgs/7999.png)]

结论

WGAN从数学的角度,层层分析原始GAN所存在的问题,并且提出一种新的测度,这使得GAN更加具有鲁棒性和稳定性,自己在学习过程中也深感数学之伟大。如果在翻阅本博客时,看到错误的地方请即使指出,Pytorch代码我已经放到本人GitHub上,链接在下面参考文献中。与君共勉。

参考文献

令人拍案叫绝的Wasserstein GAN

Wasserstein GAN

WGAN的来龙去脉

W-GAN系 (Wasserstein GAN、 Improved WGAN)

pytorch WGAN

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值