WGAN-算法和代码结合点

写这篇文章不是为了介绍WGAN,而是为了把WGAN和Code对应起来。因为最后还是要付诸于实践,而不仅仅就是 介绍什么,因为关于WGAN好的介绍已经很多了。

如果想看懂,请认真仔细看。鄙人也是查询并收集了各种资料,学习了好久。同时也一步步推导过,所以这是一个慢慢的任务,认真看,别太着急。等你把这个看懂了并且可以自己写出来的时候,表示你已经看懂了。

看懂了,点个赞再走呗!!!

目录

一、WGAN比GAN的优越性

二、WGAN

1、基础理解

2、参数解释

3、代码解释

4、算法解释

一、WGAN比GAN的优越性

Wasserstein GAN(下面简称WGAN)成功地做到了以下爆炸性的几点:

  1. 彻底解决GAN训练不稳定的问题,不再需要小心平衡生成器和判别器的训练程度
  2. 基本解决了collapse mode的问题,确保了生成样本的多样性
  3. 训练过程中终于有一个像交叉熵、准确率这样的数值来指示训练的进程,这个数值越小代表GAN训练得越好,代表生成器产生的图像质量越高
  4. 以上一切好处不需要精心设计的网络架构,最简单的多层全连接网络就可以做到

二、WGAN

1、基础理解

2、参数解释

从公式14——>公式15

重点 

构造一个含参数w、最后一层不是sigmoid的判别器网络fw,在限制w不超过某个范围的条件下,使得公式15尽可能取到最大,此时L就会近似真实分布与生成分布之间的Wasserstein距离(忽略常数倍数)

 

在判别器中:

公式17是公式15的变形(取负),上面提到过“使得公式15尽可能取到最大”,同时这还是一个最大最小的博弈问题(original GAN ),但是在code中往往是最小化损失函数,所以采用公式17的损失函数这里取负,用来更新判别器的参数。

在生成器中:

损失函数,min max L,在判别器中,已经更新了fw之后,进入了生成器(此时fw是固定的),公式15中,第一项和生成器无关,所以公式16中fw(x)等价于fw(g(z)),进而来更新生成器的权重参数。

3、代码解释

WGAN与原始GAN第一种形式相比,只改了四点:

  • 判别器最后一层去掉sigmoid
  • 生成器和判别器的loss不取log
  • 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
  • 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行
     

1)判别器最后一个去掉了sigmoid

2)生成器和判别器的loss不取log

     

3)每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c

4)不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行

4、算法解释

preview

Require:w表示判别器参数(这里采用critic 代替了Discriminator),θ表示生成器参数(Generator)

判别器:n_{critic} = 5

2:n_{critic}说明的是判别器更新n_{critic}次,然后生成器才更新一次,这里n_{critic} = 5

3-4:表示采样

5:采用公式17,计算损失函数,来更新判别器的参数w,这里的w指的是判别器网络中每一层的参数

6:梯度更新优化w

7:每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c

生成器:

10:生成器的损失函数,根据判别器的fw,构造出更新θ的损失函数

11:梯度更新优化θ

WGAN-GP是一种改进的生成对抗网络(GAN)模型,它在原始的Wasserstein GAN基础上添加了梯度惩罚(Gradient Penalty)项。PyTorch是一个开源的深度学习框架,用于实现和训练神经网络模型。 WGAN-GP的基本思想是通过训练一个生成器和一个判别器来实现生成新样本的目标。生成器尝试产生与真实样本相似的样本,而判别器则努力区分生成样本和真实样本。Wasserstein GAN使用Earth-Mover(EM)距离作为判别器的损失函数,以提升训练稳定性。然而,EM距离的计算涉及到判别器的Lipschitz约束,这个约束很难满足,而且难以实现。 WGAN-GP则通过梯度惩罚项解决了Lipschitz约束的问题。梯度惩罚项是通过对真实样本和生成样本之间的线性插值进行随机采样,并对判别器输出的梯度进行惩罚来实现的。具体而言,用于计算梯度的范数的平方作为惩罚项,将梯度限制在一个合理的范围内。 在PyTorch中,可以使用torch.nn.Module类来定义生成器和判别器模型,并且可以使用torch.optim优化器来更新参数。通过在训练过程中交替更新生成器和判别器,逐步提升生成样本的质量。 WGAN-GP的PyTorch实现包括以下步骤: 1. 定义生成器和判别器的网络结构。 2. 定义损失函数,其中包括Wasserstein距离和梯度惩罚项。 3. 定义优化器,如Adam或SGD。 4. 进行训练迭代,包括前向传播生成样本,计算损失,反向传播和参数更新。 总之,WGAN-GP是一种改进的GAN模型,在PyTorch中可以轻松实现和训练。它通过引入梯度惩罚项解决了Lipschitz约束的问题,使得训练过程更加稳定,并且能够生成更高质量的样本。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值