深度学习:GAN(2)

本文介绍了Wasserstein GAN(WGAN)的概念,通过一个使用混合高斯分布的简单示例展示了WGAN的训练过程。文章详细阐述了生成器G和判别器D的构建,并解释了如何在训练中优化D和G。WGAN解决了原始GAN损失函数的弥散问题,通过引入梯度惩罚项改进了训练效果。提供了完整的代码示例,便于读者在实践中理解WGAN。
摘要由CSDN通过智能技术生成

上一节说明了GAN的训练流程和损失函数,现在用Pytorch来写下简单2017年提出的WGAN。

为了方便学习分布,我们用简单的混合高斯分布来作为真实的数据分布。

1 导入常用的包
import torch
from torch import nn, optim, autograd
 # 模型,优化器和自动求导的包,这里我们需要手动对loss求导所需要这个包
import numpy as np
import visdom
from torch.nn import functional as F
from matplotlib import pyplot as plt
import random   
2 超参数设置

根据自己的电脑配置设置。

h_dim = 400
batchsz = 512
viz = visdom.Visdom()
3 生成器G

首先输入是自己定义的维度,z:[b,2],像生成图片,我们可以设置为100,在这里我们是生成高斯分布的数据,因此设置为2。可以从代码看出一共有4层。

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
					            nn.Linear(2, h_dim),
					            nn.ReLU(True),
					            nn.Linear(h_dim, h_dim),
					            nn.ReLU(True),
					            nn.Linear(h_dim, h_dim),
					            nn.ReLU(True),
					            nn.Linear(h_dim, 2), #这里输出的维度也是为2,为了可视化
        						)
    def forward(self, z):
        output = self.net(z)
        return output
4 判别器D

可以看到,判别器也是4层,输出的是概率值(标量)。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
					            nn.Linear(2, h_dim),
					            nn.ReLU(True),
					            nn.Linear(h_dim, h_dim),
					            nn.ReLU(True),
					            nn.Linear(h_dim, h_dim),
					            nn.ReLU(True),
					            nn.Linear(h_dim, 1)  #输出标量
					            nn.Sigmoid()
					          )
    def forward(self, x):
        output = self.net(x)
        return output.view(-1)
5 生成数据集

不同于MINSIT、CIFAR10这种数据集,这个例子用的是8个混合高斯分布,便于可视化。

def data_generator():
    scale = 2.
    centers = [
		        (1, 0),
		        (-1, 0),
		        (0, 1),
		        (0, -1),
		        (1. / np.sqrt(2), 1. / np.sqrt(2)),
		        (1. / np.sqrt(2), -1. / np.sqrt(2)),
		        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
		        (-1. / np.sqrt(2), -1. / np.sqrt(2))
		    ] #8个中心点
    centers = [(scale * x, scale * y) for x, y in centers]
    while True: #死循环
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * .02
            center = random.choice(centers) #随机选取一个点
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset, dtype='float32')
        dataset /= 1.414  # stdev
        yield dataset  #yield到这一步就停止,下次再接着这一步开始,不然就一直死循环。

可以详细看下yeild用法:yield的用法详解

6 开始写主函数,主要是如何训练
6.1 先优化D
def main():
    torch.manual_seed(23)
    np.random.seed(23)
    G = Generator().cuda(
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值