chapter1-GAN

0.引言

        最原始的GAN发表在NIPS2014上,GAN的思想是非常有趣的,以前我们在设计网络的时候,都是利用一个目标函数来指导一个网络的训练。然而GAN利用两种目标函数来指导两个结构(生成器与判别器)的博弈对抗,让两者互相促进。

        举个简单的对抗博弈进步的例子-假币与银行,在初期的时候,由于银行并没有认识到假币的存在,因此一些很简单的假币就可能会被银行误收;随着科技进步,开始有了水印 、紫外光辨别真伪,在这样的条件下,旧的假币就无法通过银行的检验了,但是旧的假币可以知道自己不能通过检验的原因在于水印等不合格,于是假币就在这些方面进行改进,以此想要继续骗过银行,因此银行和假币就处在不断地博弈进步之中 。

        那么现在我们来看一下GAN的工作原理,实际上GAN网络分为两部分-生成器判别器,生成器的目的就是造假,而判别器的目的就是判别真假,当生成器和判别器可以正确对抗博弈进化时,随着判别器越来越强,生成器也会越来越强,而做生成任务时,我们需要的就是生成器部分。

1.生成器

A.生成器的训练目标

        生成器比较复杂,首先我们并不知道怎样做生成这件事,因为我们并不知道用什么去生成,对于传统的神经网络,我们知道输入和输出,而对于生成任务来说,我们希望生成器可以给我们和训练数据属于同种的数据,简单来说就是,如果我们给定的数据是海边风景照,那么我们希望生成器可以给我们生成海边风景照,但是海边风景照数量非常庞大,因为涉及的因素非常多,比如天气、沙滩状况、涨潮情况,稍微变动一点就是不同的数据了,所以我们不可能收集所有的海边风景照。

        那么GAN是如何解决这个问题的呢,对于实际数据数量庞大而我们所拥有的数量有限问题,GAN的做法是从分布的角度去看待数据,即我们训练的GAN的目标不是直接生成数据,而是生成分布。那我们到底要怎么用GAN去拟合一个分布呢?并且输入应该是什么样子的呢?

B.生成器的输入

        为了避免直接做拟合分布这个0->1的问题,GAN的做法是将另一个分布通过生成器映射为我们需要的目标分布,由于分布都是无限的,因此我们可以通过映射将一个分布转换为另一个分布。说到分布,常用的就是高斯、均匀等,一般的做法都是使用高斯分布。

        现在我们确定了方案,即生成器的目标是将一个分布转换为目标分布。那么怎么做呢?假设我们要生成28*28*1的图像数据,我们将每一个点位置的数据视为从一个高斯分布中采样的数据,即通过高斯分布获得28*28*1的输入(称之为输入噪声),此时再通过网络对于采样得到的28*28*1进行转换,这看起来确实是一种方案,但实际上我们通常不会这么做,原因在于我们可能要生成很大的数据,假设我们要生成一个1024*1024的数据,如果我们在输入端的数据就这么大,那么整个网络是十分消耗资源的。我们通常的做法是采样出一个1*z_dim维度的向量,并且认为这个向量的含义是对目标生成数据的高级特征描述,简言之可以理解为肤色的深度、头发的浓密度等等。

C.生成器如何生成

        现在我们确定了输入是一个向量,我们要怎么做到生成呢,首先必要的就是需要进行高级特征的解耦,这个过程是我们以前所做的分类网络的逆过程,我们将输入向量理解为分类网络在判别层之前的特征向量,那么为了获得图片,我们需要向浅层解耦,即将高级特征解耦为低级特征,所以首先要做的就是升维,之后的步骤就是利用得到的低级特征向量进行生成。

2.判别器

        判别器的结构相对简单,主要进行判别真假的任务,可以参考在初学神经网络时候所写的分类网络,把分物体类别的问题转化为分物体是真实还是虚假的问题。

3.训练

A.损失函数

        GAN如何训练的问题很大程度上取决于损失函数的设计,我们先梳理一下,对于判别器来说,目标是要判别出真实数据(判别为1)和判别数据(判别为0),生成器的目的是生成让判别器误判的数据。我们这里先规定一下符号,生成器用G表示,判别器用D表示,输入的采样向量用z表示,生成数据为G(z),真实数据用x表示。

        对于判别器来说,希望判别真实数据的结果接近1,即D(x)->1,判别虚假数据的结果接近0,即D(G(z))->0,GAN设计的损失函数判别器部分如下:

loss_D=max[log(D(x))+log(1-D(G(z)))]

        对于生成器来说,希望判别器判别生成的结果接近1,即D(G(z))->1,GAN设计的损失函数生成器部分如下:

loss_G=min[log(1-D(G(z)))]

        合在一起就是GAN论文中写的损失函数了:

min_G\ max_D\ E_{x\in p_{data}}log(D(x))+E_{z\in p_{z}}log(1-D(G(z)))]

B.训练策略

        在GAN的训练中,有非常危险的问题就是博弈失败,我们知道判别器和生成器是互相促进的,但如果他们的训练不协调就会出现问题。

第一种情况:如果D的进化很快,D很容易区分生成数据与判别数据,那么此时训练D使用的损失函数值就很快接近最小值,梯度也急剧变小,由于G的梯度求解是依赖D的梯度的,当D的梯度很小的时候G的梯度也会变得很小,也就是G无法进化。

第二种情况:如果G的进化很快,G很快就可以欺骗过判别器,此时判别器无法区分真实数据与生成数据,也就不能提供梯度信息,导致G无法再进化,但是G可能并没有进化的很好。

        因此如何平衡生成器和判别器的训练是很重要的。

4.代码

import torch
from torch import nn, autograd, optim
import numpy as np
import visdom
import random
import matplotlib.pyplot as plt

h_dim = 400
bs = 512
viz = visdom.Visdom()

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    def forward(self, z):
        output = self.net(z)
        return output

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)


def data_generator():
    scale = 2
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2)),
    ]
    centers = [(scale * x, scale * y) for x, y in centers]

    while True:
        dataset = []
        for i in range(bs):
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]

            dataset.append(point)

        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.414
        yield dataset

def main():
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()
    x = next(data_iter)

    G = Generator().cuda()
    D = Discriminator().cuda()
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss',legend=['D', 'G']))

    for epoch in range(50000):

        for _ in range(5):
            xr = next(data_iter)
            xr = torch.from_numpy(xr).cuda()
            predr = D(xr)
            # max predr, min lossr, min -predr
            lossr = -predr.mean()

            z = torch.randn(bs, 2).cuda()
            xf = G(z).detach()
            predf = D(xf).requires_grad_()
            lossf = predf.mean()

            loss_D = lossr + lossf

            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        z = torch.randn(bs, 2).cuda()
        xf = G(z)
        predf = D(xf)
        # max predf
        loss_G = -predf.mean()

        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')

            print(loss_D.item(), loss_G.item())

            generate_image(D, G, xr.cpu().numpy(), epoch)

def generate_image(D, G, xr, epoch):
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype="float32")
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))

    with torch.no_grad():
        points = torch.Tensor(points).cuda()
        disc_map = D(points).cpu().numpy()

    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    cs = plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())
    plt.clabel(cs, inline=1, fontsize=10)

    with torch.no_grad():
        z = torch.randn(bs, 2).cuda()
        samples = G(z).cpu().numpy()
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d'%epoch))

if __name__ == '__main__':
    main()

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值