WGAN-GP: a minimal PyTorch example on a 2-D mixture of 8 Gaussians

import random
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn, optim, autograd
from visdom import Visdom

# real-data sampler
def data_generator():
    """
        The real data distribution is a mixture of 8 Gaussians whose
        centers sit on a circle.
    """
    scale = 2.
    centers = [
        (1, 0), (-1, 0), (0, 1), (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for i in range(batch_size):
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset, dtype='float32')
        dataset /= 1.414  # rescale by ~sqrt(2)
        yield dataset
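
# A quick, hedged sanity check of the sampler (plot_real_batch is our own
# helper name, not part of the original script): draw one batch and
# scatter-plot it; the 8 modes should appear as 8 tight clusters on a
# circle of radius scale / 1.414.
def plot_real_batch():
    batch = next(data_generator())   # [batch_size, 2] float32
    plt.scatter(batch[:, 0], batch[:, 1], s=8)
    plt.title('one batch from the 8-Gaussian mixture')
    plt.show()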

# hyper-parameters
hidden_dim = 200
batch_size = 256
epochs = 5000

# visdom object (requires a running visdom server: python -m visdom.server)
vis = Visdom()
# device (fall back to CPU when CUDA is unavailable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # input: [batch, 2], random 2-D noise points
            nn.Linear(2, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            # output: [batch, 2], fake points meant to match the real distribution
            nn.Linear(hidden_dim, 2)
        )

    def forward(self, z):
        """
        :param z:  [batch, 2] 随机生成的二维点
        :return:   [batch, 2] Generator生成的尽量满足真实数据分布的fake-data
        """
        output = self.net(z)
        return output

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            # input: [batch, 2], a 2-D data point (real or fake)
            nn.Linear(2, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            # one unbounded critic score per point: higher means "more real".
            # Note: a WGAN-GP critic has no sigmoid on its output; the
            # Wasserstein losses below operate on raw scores.
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        """
            输入二维数据点,判断是否满足预定义分布(real-data or fake-data)
        :param x:
        :return:
        """
        output = self.net(x)
        return output
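
# A hedged smoke test for the two networks (smoke_test_shapes is our own
# name): confirm the expected tensor shapes flow through G and D.
def smoke_test_shapes():
    g, d = Generator().to(device), Discriminator().to(device)
    z = torch.randn(batch_size, 2).to(device)
    fake = g(z)        # [batch_size, 2] generated 2-D points
    score = d(fake)    # [batch_size, 1] critic score per point
    assert fake.shape == (batch_size, 2)
    assert score.shape == (batch_size, 1)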


def weights_init(m):
    if isinstance(m, nn.Linear):
        # m.weight.data.normal_(0.0, 0.02)
        nn.init.kaiming_normal_(m.weight)
        m.bias.data.fill_(0)


def gradient_penalty(D, xr, xf):
    LAMBDA = 0.3
    # the penalty constrains only the Discriminator, so detach both inputs
    # to keep gradients from flowing back into G
    xf = xf.detach()
    xr = xr.detach()
    # [b, 1] => [b, 2]: one interpolation coefficient per sample
    alpha = torch.rand(xr.size(0), 1).to(device)
    alpha = alpha.expand_as(xr)

    # random points on the lines between real and fake samples
    interpolates = alpha * xr + (1 - alpha) * xf
    interpolates.requires_grad_()

    pred = D(interpolates)

    gradients = autograd.grad(outputs=pred, inputs=interpolates,
                              grad_outputs=torch.ones_like(pred),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gp
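
# The function above implements the WGAN-GP regularizer
#     gp = LAMBDA * E[(||grad_x_hat D(x_hat)||_2 - 1)^2]
# with x_hat sampled uniformly on lines between real and fake points
# (the paper uses LAMBDA = 10; 0.3 is this script's choice for the toy
# problem). A hedged sanity check (check_gp is our own name):
def check_gp():
    d = Discriminator().to(device)
    xr = torch.randn(batch_size, 2).to(device)
    xf = torch.randn(batch_size, 2).to(device)
    gp = gradient_penalty(d, xr, xf)
    assert gp.dim() == 0 and gp.item() >= 0.  # a non-negative scalar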


def train_GAN():
    # fix the seeds so runs are reproducible; random.seed also covers the
    # random.choice call inside data_generator
    torch.manual_seed(23)
    np.random.seed(23)
    random.seed(23)

    G = Generator().to(device)
    D = Discriminator().to(device)
    G.apply(weights_init)
    D.apply(weights_init)
    optimizer_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
    optimizer_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))

    data_iter = data_generator()
    vis.line([[0., 0.]], [0.], win='Loss Info',
             opts=dict(title='Loss Info', legend=['Loss_G', 'Loss_D']))

    # train the Discriminator (critic) and the Generator in alternating steps;
    # within each epoch the critic is trained k times first
    for epoch in range(epochs):
        # 1. train Discriminator for k steps
        for _ in range(5):  # k = 5 critic steps per generator step
            # real-data loss: the critic should give real points high scores
            x = next(data_iter)
            xr = torch.from_numpy(x).to(device)
            predr = D(xr)  # critic score for real data
            lossr = -predr.mean()

            # fake-data loss: the critic should give generated points low scores
            x_random = torch.randn(batch_size, 2).to(device)
            xf = G(x_random).detach()  # detach: do not backprop into G here
            predf = D(xf)
            lossf = predf.mean()

            # gradient penalty
            gp = gradient_penalty(D, xr, xf)

            # update the critic
            loss_D = lossr + lossf + gp
            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()

        # 2. train Generator: maximize the critic score on generated samples
        x_random = torch.randn(batch_size, 2).to(device)
        xf = G(x_random)
        predf = D(xf)
        loss_G = -predf.mean()
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # report progress every 100 epochs
        if (epoch + 1) % 100 == 0:
            vis.line([[loss_G.item(), loss_D.item()]], [epoch],
                     win='Loss Info', update='append')
            print("epoch:%-5i" % (epoch + 1), "Loss_G=%-5.5f" % loss_G.item(),
                  "Loss_D=%-5.5f" % loss_D.item())

    # return the trained networks so callers can sample from G afterwards
    return G, D


if __name__ == '__main__':
    G, D = train_GAN()
    plot_fake_vs_real(G)
    print("Done!")

 
