论文传送门:https://arxiv.org/pdf/1704.00028.pdf
WGAN存在的问题:在WGAN中,为使得判别器D(x)满足Lipschitz连续条件,从而对网络参数进行了[-c,c]的区间限制,使得网络参数分布极端,参数均接近于-c或c。
WGAN-gp的目的:解决WGAN参数分布极端的问题。
WGAN-gp的方法:在判别器D的loss中增加梯度惩罚项,代替WGAN中对判别器D的参数区间限制,同样能保证D(x)满足Lipschitz连续条件。(证明过程见论文补充材料)
红框部分:与WGAN不同之处,即判别器D的loss增加梯度惩罚项和优化器选择Adam。
梯度惩罚项的计算实现见代码70-87行,判别器D的损失函数修改见代码156行。
import os
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm
class Discriminator(nn.Module): # 定义判别器(WS-divergence)
def __init__(self, img_shape=(1, 28, 28)): # 初始化方法
super(Discriminator, self).__init__() # 继承初始化方法
self.img_shape = img_shape # 图片形状
self.linear1 = nn.Linear(self.img_shape[0] * self.img_shape[1] * self.img_shape[2], 512) # linear映射
self.linear2 = nn.Linear(512, 256) # linear映射
self.linear3 = nn.Linear(256, 1) # linear映射
self.leakyrelu = nn.LeakyReLU(0.2, inplace=True) # leakyrelu激活函数
def forward(self, x): # 前传函数