上一节说明了GAN的训练流程和损失函数,现在用Pytorch
来写下简单2017年提出的WGAN。
为了方便学习分布,我们用简单的混合高斯分布来作为真实的数据分布。
1 导入常用的包
import torch
from torch import nn, optim, autograd
# 模型,优化器和自动求导的包,这里我们需要手动对loss求导所需要这个包
import numpy as np
import visdom
from torch.nn import functional as F
from matplotlib import pyplot as plt
import random
2 超参数设置
根据自己的电脑配置设置。
h_dim = 400
batchsz = 512
viz = visdom.Visdom()
3 生成器G
首先输入是自己定义的维度,z:[b,2],像生成图片,我们可以设置为100,在这里我们是生成高斯分布的数据,因此设置为2。可以从代码看出一共有4层。
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.net = nn.Sequential(
nn.Linear(2, h_dim),
nn.ReLU(True),
nn.Linear(h_dim, h_dim),
nn.ReLU(True),
nn.Linear(h_dim, h_dim),
nn.ReLU(True),
nn.Linear(h_dim, 2), #这里输出的维度也是为2,为了可视化
)
def forward(self, z):
output = self.net(z)
return output
4 判别器D
可以看到,判别器也是4层,输出的是概率值(标量)。
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.net = nn.Sequential(
nn.Linear(2, h_dim),
nn.ReLU(True),
nn.Linear(h_dim, h_dim),
nn.ReLU(True),
nn.Linear(h_dim, h_dim),
nn.ReLU(True),
nn.Linear(h_dim, 1) #输出标量
nn.Sigmoid()
)
def forward(self, x):
output = self.net(x)
return output.view(-1)
5 生成数据集
不同于MINSIT、CIFAR10这种数据集,这个例子用的是8个混合高斯分布,便于可视化。
def data_generator():
scale = 2.
centers = [
(1, 0),
(-1, 0),
(0, 1),
(0, -1),
(1. / np.sqrt(2), 1. / np.sqrt(2)),
(1. / np.sqrt(2), -1. / np.sqrt(2)),
(-1. / np.sqrt(2), 1. / np.sqrt(2)),
(-1. / np.sqrt(2), -1. / np.sqrt(2))
] #8个中心点
centers = [(scale * x, scale * y) for x, y in centers]
while True: #死循环
dataset = []
for i in range(batchsz):
point = np.random.randn(2) * .02
center = random.choice(centers) #随机选取一个点
point[0] += center[0]
point[1] += center[1]
dataset.append(point)
dataset = np.array(dataset, dtype='float32')
dataset /= 1.414 # stdev
yield dataset #yield到这一步就停止,下次再接着这一步开始,不然就一直死循环。
可以详细看下yeild用法:yield的用法详解。
6 开始写主函数,主要是如何训练
6.1 先优化D
def main():
torch.manual_seed(23)
np.random.seed(23)
G = Generator().cuda(