PyTorch中的生成对抗网络(GAN)

2014年,蒙特利尔大学的伊恩·古德费洛(Ian Goodfellow)和他的同事发表了一篇惊人的论文,向世界介绍了GAN(即生成性对抗网络)。通过将计算图和博弈论进行创新的组合,他们表明,只要具有足够的建模能力,两个相互竞争的模型就可以通过简单的反向传播进行协同训练。

这些模型扮演两个不同的角色(从字面上看是对抗角色)。给定一些真实数据集R,G是生成器,试图创建看起来像真实数据的伪数据,而D是鉴别器,从真实集或G中获取数据并标记差异。古德费勒的比喻(也是一个很好的比喻)是:G就像是一群伪造者,他们试图将真实的绘画与其输出进行匹配,而D则是一群侦探,试图说出区别。(除非在这种情况下,伪造者G永远不会看到原始数据——只有d的判断。他们就像盲目的伪造者。)
在这里插入图片描述

在理想情况下,随着时间的推移,D和G都会变得更好,直到G基本上成为真品的“伪造大师”,而D则不知所措,“无法区分这两种分布”。

在实践中,古德费洛展示的是G可以在原始数据集上执行某种形式的无监督学习,找到某种方式(可能)以较低维度表示该数据。正如Yann LeCun所说的那样,无监督学习是真正AI的“蛋糕”。

这项功能强大的技术似乎仅需要很多代码才能入门,对吗?不。使用PyTorch,我们实际上可以在50行以下的代码中创建一个非常简单的GAN。实际上只有5个组件需要考虑:

  • R:原始的真实数据集
  • I:随机噪声作为噪声源进入发生器
  • G:试图复制/模仿原始数据集的生成器
  • D:鉴别器试图区分R的G输出
  • .在实际的“训练”循环中,我们教G欺骗D,而D要小心G。

1.)R:在我们的例子中,我们将从最简单的R(钟形曲线)开始。该函数取一个平均值和一个标准差,然后返回一个函数,该函数可提供具有这些参数的高斯样本数据的正确形状。在示例代码中,我们将使用平均值4.0和标准偏差1.25。

def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian

2.)I:生成器的输入也是随机的,但是为了使我们的工作更加困难,让我们使用统一分布而不是正态分布。这意味着我们的模型G不能简单地移动/缩放输入以复制R,而是必须以非线性方式重塑数据。

def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian

3.)G:生成器是标准的前馈图-两个隐藏层,三个线性映射。我们使用的是双曲正切激活函数,因为我们像这样老派。摹会得到均匀的分布数据样本我,不知怎么模仿的通常从分布的样本[R -而没有看到[R 。

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.map1(x)
        x = self.f(x)
        x = self.map2(x)
        x = self.f(x)
        x = self.map3(x)
        return x

4.)D:鉴别器代码与G的生成器代码非常相似;具有两个隐藏层和三个线性映射的前馈图。这里的激活函数是一个sigmoid——没什么特别的,各位。它将从R或G中获取样本,并将输出介于0和1之间的单个标量,解释为“假”与“真实”。换句话说,这是神经网络所能做的最简单的事情。

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.f(self.map3(x))

5.)最后,训练循环在两种模式之间交替进行:第一种是真实数据训练,另一种是带有精确标签的虚假数据训练(可以将其视为警察学院);然后用不准确的标签训练G去愚弄D(这更像是《十一罗汉》里的准备蒙太奇)。这是一场正义与邪恶的战争。

for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        D.zero_grad()

        #  1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        d_real_error = criterion(d_real_decision, Variable(torch.ones([1,1])))  # ones = true
        d_real_error.backward() # compute/store gradients, but don't change params

        #  1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1,1])))  # zeros = fake
        d_fake_error.backward()
        d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

        dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]

    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()

        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        g_error = criterion(dg_fake_decision, Variable(torch.ones([1,1])))  # Train G to pretend it's genuine

        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters
        ge = extract(g_error)[0]

即使您以前从未看过PyTorch,也可以判断发生了什么。在第一(d_index)部分中,我们将两种类型的数据都通过D,然后对D的猜测与实际标签应用可区分的标准。推动是“前进”的一步;然后,我们明确调用“ backward()”以计算梯度,然后将其用于在d_optimizer step()调用中更新D的参数。使用G,但此处未进行训练。

然后,在最后一个(g_index)部分中,我们对G执行相同的操作-请注意,我们也通过D运行了G的输出(本质上是给伪造者提供了一个可以进行练习的侦探),但我们并未优化或更改D在这一步。我们不希望侦探D学习错误的标签。因此,我们仅调用g_optimizer.step()。

而且…仅此而已。还有其他一些样板代码,但GAN特定的东西只是这5个组件,仅此而已。

在D和G之间这种禁止的舞蹈进行了数千回合之后,我们会得到什么?判别器D很快变得很好(而G缓慢上升),但是一旦达到一定程度的力量,G便成为了一个有价值的对手并开始进步。真正提高。

超过5,000轮训练,每轮训练D 20次,然后训练G 20次,G的输出平均值超过4.0,但随后回到相当稳定的正确范围内(左)。同样,标准偏差最初会朝错误的方向下降,然后上升到所需的1.25范围(右),与R匹配。
在这里插入图片描述

好的,因此基本属性最终匹配R。更高的时刻怎么样?分布的形状看起来正确吗?毕竟,您当然可以具有均值为4.0且标准偏差为1.25的均匀分布,但这与R并不完全匹配。让我们看一下G发出的最终分布:
在这里插入图片描述
不错。右尾比左尾稍微胖一点,但是偏度和峰度是原始高斯分布的再现。

G几乎完美地恢复了原始分布R,而D则退缩在角落,喃喃自语,无法从小说中分辨出事实。这正是我们想要的行为(请参见古德费洛中的图1)。从少于50行的代码开始。

现在,警告一下:GAN可能会很挑剔。而且脆弱。而且当他们进入怪异状态时,他们常常会在没有一点哄骗的情况下出来。运行我的示例代码十次(每次超过5,000发)显示了以下十个分布:
在这里插入图片描述
10次测试中有8次获得了非常好的最终分布——近似于高斯分布,均值为4,标准差在正确的范围内。但是两次运行没有——在第5次运行中,有一个凹分布,平均值在6.0左右,在最后一次运行(第10次),在11处有一个狭窄的峰值!当你开始在几乎任何环境中应用GANs时,你会发现这个现象——GANs并不像一般的监督学习工作流那样稳定。但当它们发挥作用时,它们看起来几乎是神奇的。
去看看代码

#!/usr/bin/env python

# Generative Adversarial Networks (GAN) example in PyTorch. Tested with PyTorch 0.4.1, Python 3.6.7 (Nov 2018)
# See related blog post at https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.sch4xgsa9

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

matplotlib_is_available = True
try:
  from matplotlib import pyplot as plt
except ImportError:
  print("Will skip plotting; matplotlib is not available.")
  matplotlib_is_available = False

# Data params
data_mean = 4
data_stddev = 1.25

# ### Uncomment only one of these to define what data is actually sent to the Discriminator
#(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
#(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)
#(name, preprocess, d_input_func) = ("Data and diffs", lambda data: decorate_with_diffs(data, 1.0), lambda x: x * 2)
(name, preprocess, d_input_func) = ("Only 4 moments", lambda data: get_moments(data), lambda x: 4)

print("Using data [%s]" % (name))

# ##### DATA: Target data and generator input data

def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian

def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian

# ##### MODELS: Generator model and discriminator model

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.map1(x)
        x = self.f(x)
        x = self.map2(x)
        x = self.f(x)
        x = self.map3(x)
        return x

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.f(self.map3(x))

def extract(v):
    return v.data.storage().tolist()

def stats(d):
    return [np.mean(d), np.std(d)]

def get_moments(d):
    # Return the first 4 moments of the data provided
    mean = torch.mean(d)
    diffs = d - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0  # excess kurtosis, should be 0 for Gaussian
    final = torch.cat((mean.reshape(1,), std.reshape(1,), skews.reshape(1,), kurtoses.reshape(1,)))
    return final

def decorate_with_diffs(data, exponent, remove_raw_data=False):
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    if remove_raw_data:
        return torch.cat([diffs], 1)
    else:
        return torch.cat([data, diffs], 1)

def train():
    # Model parameters
    g_input_size = 1      # Random noise dimension coming into generator, per output vector
    g_hidden_size = 5     # Generator complexity
    g_output_size = 1     # Size of generated output vector
    d_input_size = 500    # Minibatch size - cardinality of distributions
    d_hidden_size = 10    # Discriminator complexity
    d_output_size = 1     # Single dimension for 'real' vs. 'fake' classification
    minibatch_size = d_input_size

    d_learning_rate = 1e-3
    g_learning_rate = 1e-3
    sgd_momentum = 0.9

    num_epochs = 5000
    print_interval = 100
    d_steps = 20
    g_steps = 20

    dfe, dre, ge = 0, 0, 0
    d_real_data, d_fake_data, g_fake_data = None, None, None

    discriminator_activation_function = torch.sigmoid
    generator_activation_function = torch.tanh

    d_sampler = get_distribution_sampler(data_mean, data_stddev)
    gi_sampler = get_generator_input_sampler()
    G = Generator(input_size=g_input_size,
                  hidden_size=g_hidden_size,
                  output_size=g_output_size,
                  f=generator_activation_function)
    D = Discriminator(input_size=d_input_func(d_input_size),
                      hidden_size=d_hidden_size,
                      output_size=d_output_size,
                      f=discriminator_activation_function)
    criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
    d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)
    g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)

    for epoch in range(num_epochs):
        for d_index in range(d_steps):
            # 1. Train D on real+fake
            D.zero_grad()

            #  1A: Train D on real
            d_real_data = Variable(d_sampler(d_input_size))
            d_real_decision = D(preprocess(d_real_data))
            d_real_error = criterion(d_real_decision, Variable(torch.ones([1,1])))  # ones = true
            d_real_error.backward() # compute/store gradients, but don't change params

            #  1B: Train D on fake
            d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1,1])))  # zeros = fake
            d_fake_error.backward()
            d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

            dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]

        for g_index in range(g_steps):
            # 2. Train G on D's response (but DO NOT train D on these labels)
            G.zero_grad()

            gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            g_fake_data = G(gen_input)
            dg_fake_decision = D(preprocess(g_fake_data.t()))
            g_error = criterion(dg_fake_decision, Variable(torch.ones([1,1])))  # Train G to pretend it's genuine

            g_error.backward()
            g_optimizer.step()  # Only optimizes G's parameters
            ge = extract(g_error)[0]

        if epoch % print_interval == 0:
            print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s),  Fake Dist (%s) " %
                  (epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))

    if matplotlib_is_available:
        print("Plotting the generated distribution...")
        values = extract(g_fake_data)
        print(" Values: %s" % (str(values)))
        plt.hist(values, bins=50)
        plt.xlabel('Value')
        plt.ylabel('Count')
        plt.title('Histogram of Generated Distribution')
        plt.grid(True)
        plt.show()


train()

在这里插入图片描述

  • 5
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值