文献阅读:WGAN

摘要

自从2014年Ian Goodfellow提出以来,GAN就存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题。从那时起,很多论文都在尝试解决,但是效果不尽人意,比如最有名的一个改进DCGAN依靠的是对判别器和生成器的架构进行实验枚举,最终找到一组比较好的网络架构设置,但是实际上是治标不治本,没有彻底解决问题,而Wasserstein GAN彻底解决GAN训练不稳定的问题,不再需要小心平衡生成器和判别器的训练程度,还基本解决了collapse mode的问题,确保了生成样本的多样性,并且训练过程中终于有一个像交叉熵、准确率这样的数值来指示训练的进程,这个数值越小代表GAN训练得越好,代表生成器产生的图像质量越高,本文主要介绍GAN原始存在的问题和WGAN实现突破的四个重点,并应用于代码。

Abstract

.Since Ian Goodfellow proposed GAN in 2014, it has faced difficulties in training, inability of the generator and discriminator losses to indicate the training progress, and lack of diversity in generated samples. Since then, many papers have attempted to address these issues, but with unsatisfactory results. For example, the most famous improvement, DCGAN, relied on experimenting with different architectures for the discriminator and generator, eventually finding a set of network architecture settings that worked relatively well. However, this was just a temporary solution and did not fundamentally solve the problem. On the other hand, Wasserstein GAN completely addresses the issue of unstable GAN training. It no longer requires careful balancing of the training of the generator and discriminator, and it also largely solves the problem of collapse mode, ensuring diversity in generated samples. Additionally, during the training process, there is finally a numerical value, similar to cross-entropy and accuracy, that indicates the progress of training. A smaller value indicates better GAN training and higher quality of generated images. This article primarily discusses the original problems of GAN and the four key breakthroughs achieved through WGAN implementation, which are also applied in the provided code.

Wasserstein GAN

文献:https://arxiv.org/abs/1701.07875

1、原始GAN中存在的问题

实际训练中,GAN存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题。这与GAN的机制有关。GAN最终达到对抗的纳什均衡只是一个理想状态,而现实情况中得到的结果都是中间状态(伪平衡)。

大部分的情况是,随着训练的次数越多判别器D的效果越好,会导致一直可以将生成器G的输出与真实样本区分开。这是因为生成器G是从低维空间向高维空间(复杂的样本空间)映射,其生成的样本分布空间Pg难以充满整个真实样本的分布空间Pr。即两个分布完全没有重叠的部分,或者它们重叠的部分可以忽略,这样就使得判别器D总会将它们分开。

在原始GAN的训练中,判别器训练得太好,会使生成器梯度消失,生成器loss降不下去;判别器训练得不好,会使生成器梯度不准,四处乱跑。只有判别器训练到中间状态最佳,但是这个尺度很难把握,没有一个收敛判断的依据。甚至在同一轮训练的前后不同阶段,这个状态出现的时段都不一样,是个完全不可控的情况。

引入Kullback–Leibler divergence(简称KL散度)和Jensen-Shannon divergence(简称JS散度)这两个重要的相似度衡量指标,后面的主角之一Wasserstein距离,就是要来吊打它们两个的。所以接下来介绍这两个重要的配角——KL散度和JS散度:

根据原始GAN定义的判别器loss,我们可以得到最优判别器的形式;而在最优判别器的下,我们可以把原始GAN定义的生成器loss等价变换为最小化真实分布与生成分布之间的JS散度。我们越训练判别器,它就越接近最优,最小化生成器的loss也就会越近似于最小化和之间的JS散度。
img

这过程看似非常的合理,只要我们不断的训练,真实分布和生成分布越来越接近,JS散度应该越来越小,直到两个分布完全一致,此时JS散度为0。但是理想很丰满,现实很骨感,JS散度并不会随着真实分布和生成分布越来越近而使其值越来越小,而是很大概率保持log2不变。

为什么两个分布不重合,它们JS散度一直为log2?

1705839157321

现在大家注意到上式已经有一个log2了,为证明我们的结论,我们只需证明两个分布不重合时上式左侧两项为0即可。我们以下图为例:img

1705839181345

2.研究内容

2.1 EM distance(推土机距离)的引入

我们上文详细分析了普通GAN存在的缺陷,主要是由于和JS散度相关的损失函数导致的。大佬们就在思考能否有一种损失能够代替JS散度呢?于是,WGAN应运而生,其提出了一种新的度量两个分布距离的标准——Wasserstein Metric,也叫推土机距离(Earth-Mover distance)。

下图左侧有6个盒子,我们期望将它们都移动到右侧的虚线框内。比如将盒子1从位置1移动到位置7,移动了6步,我们就将此步骤的代价设为6;同理,将盒子5从位置3移动到位置10,移动了7步,那么此步骤的代价为7,依此类推。

1

很显然,我们有很多种不同的方案,下图给出了两种不同的方案:

2

上图中右侧的表格表示盒子是如何移动的。比如在γ1 中,其第一行第一列的值为1,表示有1个框从位置1移动到了位置7;第一行的第四列的值为2,表示有2个框从位置1移动到了位置10。上图中给出的两种方案的代价总和都为42,但是对于一个问题并不是所有的移动方案的代价都是固定的,比如下图:

3

在这个例子中,上图展示的两种方案的代价是不同的,一个为2,一个为6。,而推土机距离就是穷举所有的移动方案,最小的移动代价对应的就是推土机距离。对应本列来说,推土机距离等于2。

现给出推土机距离是数学定义,如下:

W ( P r , P g ) W(P_r,P_g) W(Pr,Pg)=KaTeX parse error: Limit controls must follow a math operator at position 7: {\inf}\̲l̲i̲m̲i̲t̲s̲_{\gamma\in \pr…

∏ ( P r , P g ) \prod(P_r,P_g) (Pr,Pg)表示边缘分布P r 和 P g 所有组合起来的联合分布 γ \gamma γ(x,y)的集合。我们还是用图1中的例子来解释, ∏ ( P r , P g ) \prod(P_r,P_g) (Pr,Pg)就表示所有的运输方案 γ \gamma γ,下图仅列举了两种方案:

4

E ( x , y ) [ ∣ ∣ x − y ∣ ∣ ] E_{(x,y)}[||x-y||] E(x,y)[∣∣xy∣∣]可以看成是对于一个方案 γ \gamma γ移动的代价,KaTeX parse error: Limit controls must follow a math operator at position 7: {\inf}\̲l̲i̲m̲i̲t̲s̲_{\gamma\in \pr… 就表示在所有的方案中的最小代价,这个最小代价就是 W ( P r , P g ) W(P_r,P_g) W(Pr,Pg),即推土机距离。
现在我们已经知道了推土机距离是什么,但是我们还没解释清楚我们为什么要用推土机距离,即推土机距离为什么可以代替JS散度成为更优的损失函数?我们来看这样的一个例子,如下图所示:

img

上图有两个分布P1和 P2 ,P1在线段AB上均匀分布,P2在CD上均匀分布,参数θ 可以控制两个分布的距离。我们由前文对JS散度的解释,可以得到:

J S ( P 1 ∣ ∣ P 2 ) JS(P1||P2) JS(P1∣∣P2)= { l o g 2 θ ≠ 0 0 θ = 0 \begin{cases} log2\quad\quadθ\neq0\\ 0 \quad\quad\quadθ=0 \end{cases} {log2θ=00θ=0

而对于推土机距离来说,可以得到:

W ( P 1 , P 2 ) W(P1,P2) W(P1,P2)=|θ|

这样对比可以看出,推土机距离是平滑的,这样在训练时,即使两个分布不重叠推土机距离仍然可以提高梯度,这一点是JS散度无法实现的。

3、WGAN的实现

现在我们已经有了推土机距离的定义,同时也解释了推土机距离相较于JS散度的优势。但是我们想要直接使用推土机距离来定义生成器的损失似乎是困难的,因为这个式子 W ( P r , P g ) W(P_r,P_g) W(Pr,Pg)=KaTeX parse error: Limit controls must follow a math operator at position 7: {\inf}\̲l̲i̲m̲i̲t̲s̲_{\gamma\in \pr… 是难以直接求解的,但是呢,作者用了一个定理将上式变化成了如下形式:

W ( P r , P g ) W(P_r,P_g) W(Pr,Pg)=1/k sup ⁡ ∣ ∣ f ∣ ∣ L < = K \sup\limits_{||f||_L<=K} ∣∣fL<=Ksup E x ∼ P r [ f ( x ) ] − E x ∼ P g [ f ( x ) ] E_{x\sim Pr}[f(x)]-E_{x\sim Pg}[f(x)] ExPr[f(x)]ExPg[f(x)]

注意上式中的f有一个限制,即 ∣ ∣ f ∣ ∣ L < = K ||f||_L<=K ∣∣fL<=K,我们称为lipschitz连续条件。这个限制其实就是限制了函数f的导数。它的定义如下: ∣ f ( x 1 ) − f ( x 2 ) ∣ ≤ k ∣ x 1 − x 2 ∣ |f(x1)-f(x2)|\le k|x1-x2| f(x1)f(x2)kx1x2∣

即:
∣ f ( x 1 ) − f ( x 2 ) ∣ x 1 − x 2 ∣ ≤ K \frac{|f(x1)-f(x2)}{|x1-x2|} \le K x1x2∣f(x1)f(x2)K

很显然,lipschitz连续就限制了f的斜率的绝对值小于等于K,这个K称为Libschitz常数。如下图所示:

image-20221005103024152

上图中log(x)的斜率无界,故log(x)不满足lipschitz连续条件;而sin(x)斜率的绝对值都小于1,故sin(x)满足lipschitz连续条件。这样,我们只需要找到一个lipschitz函数,就可以计算推土机距离了。至于怎么找这个lipschitz函数呢,就是我们搞深度学习的那一套啦,只需要建立一个深度学习网络来进行学习就好啦。实际上,我们新建立的判别器网络和之前的的基本是一致的,只是最后没有使用sigmoid函数,而是直接输出一个分数,这个分数可以反应输入图像的真实程度。

4.代码实现:

其实WGAN相较于原始GAN只做了4点改变,分别如下:

  1. 判别器最后不使用sigmoid函数
  2. 生成器和判别器的loss不取log
  3. 每次更新判别器参数后将判别器的权重截断
  4. 不适应基于动量的优化算法,推荐使用RMSProp

4.1 判别器最后不使用sigmoid函数

这个我们一般只需要删除判别器网络中的最后一个sigmoid层就可以了,非常简单。但是我还想提醒大家一下,有时候你在看别人的原始GAN时,他的判别器网络中并没有sigmoid函数,而是在定义损失函数时使用了BCEWithLogitsLoss函数,这个函数会先对数据做sigmoid,相关代码如下:

# 定义损失函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')

4.2生成器和判别器的loss不取log

我们先来看原始GAN判别器的loss是怎么定义的,如下:

d_loss_real = criterion(d_out_real.view(-1), label_real)
d_loss_fake = criterion(d_out_fake.view(-1), label_fake)
d_loss = d_loss_real + d_loss_fake

原始GAN使用了criterion函数,这就是我们上文定义的BCEWithLogitsLossBCELoss,其内部是一个log函数

而WGAN判别器的损失直接是两个期望(均值)的差,就对应理论中 E x ∼ P r [ f ( x ) ] − E x ∼ P g [ f ( x ) ] E_{x\sim Pr}[f(x)]-E_{x\sim Pg}[f(x)] ExPr[f(x)]ExPg[f(x)]这部分,但是需要加个负号,因为我们都是使用的梯度下降方法,要最小化这个损失,上式是最大化的公式,我们来看看代码中的实现:

d_loss = -(torch.mean(d_out_real.view(-1))-torch.mean(d_out_fake.view(-1))) #判别器

g_loss = criterion(d_out_fake.view(-1), label_real)  #原始GAN损失
---------------------------------------------------
g_loss = -torch.mean(d_out_fake.view(-1))            #WGAN损失

4.3 每次更新判别器参数后将判别器的权重截断

首先来说说没什么要进行权重截断,这是因为lipschitz连续条件不好确定,作者为了方便直接简单粗暴的限制了权重参数在**[-c,c]**这个范围了,这样就一定会存在一个常数K使得函数f满足lipschitz连续条件,具体的实现代码如下:

# clip D weights between -0.01, 0.01  权重剪裁
  for p in D.parameters():
  	  p.data.clamp_(-0.01, 0.01)

注意这步是在判别器每次反向传播更新梯度结束后进行的。还需要提醒大家一点,在训练WGAN时往往会多训练几次判别器然后再训练一次生成器,论文中是训练5次判别器后训练一次生成器,关于这一点在上文WGAN的流程图中也有所体现。

4.4 不适应基于动量的优化算法,推荐使用RMSProp

作者做实验发现像Adam这类基于动量的优化算法效果不好,然后使用了RMSProp优化算法。代码如下:

optimizerG = torch.optim.RMSprop(G.parameters(), lr=5e-5)
optimizerD = torch.optim.RMSprop(D.parameters(), lr=5e-5)

4.5 github源码

import argparse
import os
import numpy as np
import math
import sys
 
import torchvision.transforms as transforms
from torchvision.utils import save_image
 
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
 
import torch.nn as nn
import torch.nn.functional as F
import torch
 
os.makedirs("images", exist_ok=True)
 
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.00005, help="learning rate")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter")
parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)
 
img_shape = (opt.channels, opt.img_size, opt.img_size)
 
cuda = True if torch.cuda.is_available() else False
 
 
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
 
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
 
        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )
 
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *img_shape)
        return img
 
 
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
 
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )
 
    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity
 
 
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
 
if cuda:
    generator.cuda()
    discriminator.cuda()
 
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)
 
# Optimizers
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=opt.lr)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=opt.lr)
 
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
 
# ----------
#  Training
# ----------
 
batches_done = 0
for epoch in range(opt.n_epochs):
 
    for i, (imgs, _) in enumerate(dataloader):
 
        # Configure input
        real_imgs = Variable(imgs.type(Tensor))
 
        # ---------------------
        #  Train Discriminator
        # ---------------------
 
        optimizer_D.zero_grad()
 
        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
 
        # Generate a batch of images
        fake_imgs = generator(z).detach()
        # Adversarial loss
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
 
        loss_D.backward()
        optimizer_D.step()
 
        # Clip weights of discriminator
        for p in discriminator.parameters():
            p.data.clamp_(-opt.clip_value, opt.clip_value)
 
        # Train the generator every n_critic iterations
        if i % opt.n_critic == 0:
 
            # -----------------
            #  Train Generator
            # -----------------
 
            optimizer_G.zero_grad()
 
            # Generate a batch of images
            gen_imgs = generator(z)
            # Adversarial loss
            loss_G = -torch.mean(discriminator(gen_imgs))
 
            loss_G.backward()
            optimizer_G.step()
 
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, batches_done % len(dataloader), len(dataloader), loss_D.item(), loss_G.item())
            )
 
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
        batches_done += 1

5.总结

在本周学习中通过对WGAN文献的阅读和李宏毅关于GAN的复习,我对GAN和WGAN有了一定的了解,同时也了解了使用Pytorch实现WGAN的步骤,进一步加深了对WGAN的理解。

  • 20
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值