WGAN-gp模型——pytorch实现

论文传送门:https://arxiv.org/pdf/1704.00028.pdf

WGAN存在的问题:在WGAN中,为使得判别器D(x)满足Lipschitz连续条件,从而对网络参数进行了[-c,c]的区间限制,使得网络参数分布极端,参数均接近于-c或c。

WGAN-gp的目的:解决WGAN参数分布极端的问题。 

WGAN-gp的方法:在判别器D的loss中增加梯度惩罚项,代替WGAN中对判别器D的参数区间限制,同样能保证D(x)满足Lipschitz连续条件。(证明过程见论文补充材料)

红框部分:与WGAN不同之处,即判别器D的loss增加梯度惩罚项和优化器选择Adam

梯度惩罚项的计算实现见代码70-87行,判别器D的损失函数修改见代码156行。

import os
import torch
from torch.utils.data import DataLoader

import torch.nn as nn

from torchvision import datasets, transforms
from torchvision.utils import save_image

from tqdm import tqdm


class Discriminator(nn.Module):  # 定义判别器(WS-divergence)
    def __init__(self, img_shape=(1, 28, 28)):  # 初始化方法
        super(Discriminator, self).__init__()  # 继承初始化方法
        self.img_shape = img_shape  # 图片形状

        self.linear1 = nn.Linear(self.img_shape[0] * self.img_shape[1] * self.img_shape[2], 512)  # linear映射
        self.linear2 = nn.Linear(512, 256)  # linear映射
        self.linear3 = nn.Linear(256, 1)  # linear映射
        self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)  # leakyrelu激活函数

    def forward(self, x):  # 前传函数
评论 28
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CV_Peach

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值