AI入门:神经网络实战----WGAN

前言
在大数据时代,数据的质量和数量是最重要的。监督学习中的样本需要我们人工进行标注,而人工标注这个工作本身是非常枯燥乏味,且工作量巨大的事情。如果我们的神经网络使用非监督学习,那就可以避免标注工作,节约很多成本,同时获得更多的数据。因此非监督学习是未来的一个重要发展方向。上一节我们学的GAN就是一种非监督学习的神经网络,因此学习和改进GAN是非常重要的。

在上一节中,我们提到了GAN的一些缺点:
1、GAN要求G和D相互匹配。这一点在实际编程中是比较困难的。
2、GAN的训练比较缓慢。
3、准确率以及损失值都是不能正确估量模型的真实情况。

其中最困难的是要求G和D在训练过程中一直保持匹配。如果不匹配,则该神经网络就容易崩溃。
什么是WGAN?
WGAN由两篇论文提出,主要是数学方面的推导。我们跳过数学推导的内容,直接看结论。WGAN没有改变GAN的框架,只是对GAN进行四方面的改进:
1、判别器最后一层去掉sigmoid。
2、生成器和判别器的loss不取log。
3、每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c。后来又改为一个梯度惩罚项,相当于一个正则。
4、不要用基于动量的优化算法 (包括momentum和Adam),推荐RMSProp,SGD也行。

有了上面感性的认识之后,我们就开始实现WGAN:
Generative Network的实现:
WGAN的生成器的代码与GAN的生成器代码完全一致:就是输入一个随机噪音数据z_prior,然后经过3层全连接层,得到28 * 28 = 784维的一张图片。这张图片如果经过变形,就可以变成(28, 28)的图片,尺寸与MNIST数据集中的图片有一样的尺寸。

在这里插入图片描述

Discriminator Network的实现:
相比于GAN,WGAN的判别器的最后一层少了Sigmoid()函数:

在这里插入图片描述

WGAN的训练:
WGAN也是先训练判别器。为了训练判别器,我们必须先生成生成器的实例G和判别器的实例D。然后定义D的损失函数和优化器。这里的优化器要使用不带动量的函数,推荐使用RMSprop()。

在这里插入图片描述

下面要对判别器进行训练。这里跟GAN区别较大,GAN直接使用BCELoss()作为损失函数,但这里要自己计算损失。我们从MNIST数据集中获取到一个batch的真实图片img。把图片转化成一维数据后转化成Tensor对象。接着用判别器对这些真实图片进行预测,得到计算结果。

在这里插入图片描述

接着用噪音z生成假的图片,然后计算判别器对假图片的预测结果。与GAN中一样,这个过程中的生成器要用detach()处理。

在这里插入图片描述

这里要出现第三个改进:梯度惩罚项,也就是一个正则gp (在源代码中看gp的计算方法)。得到真假图片的预测值、梯度惩罚项之后,就可以计算总的损失了。注意,这里使用了Wasserstein距离这一个概念。也就是计算真假图片的预测值的差。系数0.5是可以自由调节的。

在这里插入图片描述

有了损失值之后,就可以从判别器的损失值开始进行反向传播,参数优化了。值得注意的是,由于WGAN网络不需要考虑G和D的匹配问题,所以我们可以先让判别器多训练几次,这样一个比较好的判别器可以让G更快的成长。

在这里插入图片描述

训练了判别器之后,就可以训练生成器了。训练生成器的过程:先获得噪音z,然后用G生成假图片。这里要训练G,所以这里不能用detach()处理。接下去是求判别器的计算结果。由于我们希望生成器的损失值变小,所以要对判别器的计算结果求负值。在得到损失值之后,就可以进行反向传播和参数优化了。至此,终于完成生成器和判别器的训练了。

在这里插入图片描述

总结:
与GAN相比,WGAN生成的图片的效果并没有太大的改进。WGAN最大的优势在于从理论上解决了GAN中G和D必须匹配,否则模型就容易崩溃的问题。从而使我们不再需要小心的设计网络,让我们可以在GAN中放心的使用VGG16、ResNet等成熟的模型,最终生成效果更好的图片。

在WGAN之后,还有BigGAN和StyleGAN等GAN的改进形式。BigGAN生成的图片的多样性好,画质非常的清晰、细腻;StyleGAN还能够控制所生成图像的高层级属性(high-level attributes),如发型、雀斑等。这些内容,我们将在后续课程中进行解释和实现。

from WGAN.WGAN_Discriminator import *
from WGAN.WGAN_Generator import *
from WGAN import WGAN_Resource

import torch
from torch import optim
from torch.autograd import Variable
from torch import autograd
from torchvision import datasets, models, transforms
import torchvision
from torchvision.utils import
  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值