生成对抗网络实现篇(GAN)

这篇主要是分析如何根据生成对抗网络原理篇(GAN)

根据原理篇知道,GAN由两个部件组成,生成器和鉴别器。而且二者优化目标也不同,但是最后都是通过鉴别器的输出来计算的。

注意本文将优化目标与损失函数做区分,优化目标和损失函数都是为了达到相同的目的,但是优化目标更加偏向于数学一点,而损失函数是在真正实现GAN时的具体做法。下面会分析。

鉴别器的优化目标

鉴别器的优化目标是让(1)式最大化:
E x ∼ p d a t a [ l o g D ( x ) ] + E z ∼ p z [ l o g ( 1 − D ( G ( z ; θ g ) ) ) ] (1) \tag{1} E_{x \sim p_{data}}[logD(x)] +E_{z \sim p_{z}}[log(1-D(G(z;\theta_g)))] Expdata[logD(x)]+Ezpz[log(1D(G(z;θg)))](1)
使(1)式最大存在两个问题,

  • 实现时如何求解积分?(最原始的GAN只支持连续数据,所以期望就是积分)
  • 上式是最大化,如何最大化?

先看第一个问题,实践中,我们是没有办法利用积分求这两个数学期望的,具体实现时常用的做法就是采样(其实就是你获取到的训练集的样本,可以看做是从 p d a t a p_{data} pdata中采样得到的,生成数据也同理),我们可以用(2)式近似代替(1)式:

1 m ∑ i m l o g D ( x ( i ) ) + 1 m ∑ i = 1 m l o g ( 1 − D ( G ( z ( i ) ; θ g ) ) ) ] (2) \tag{2} {1 \over m}\sum_{i}^{m}logD(x^{(i)}) + {1 \over m} \sum_{i=1}^{m} log(1-D(G(z^{(i)};\theta_g)))] m1imlogD(x(i))+m1i=1mlog(1D(G(z(i);θg)))](2)
至于为啥可以这样写,涉及到蒙特卡洛估计的一些知识。可以点这采样与期望

第二个问题,通常在深度学习中,最大化问题都转化为最小化问题来解决,因为这样可以使用梯度下降算法更新参数(比如很多都是通过极大似然的思想,最后转化为最小化负对数似然)。所以最大化(2)式就是最小化(3)式:

− 1 m ∑ i m l o g D ( x ( i ) ) − 1 m ∑ i = 1 m l o g ( 1 − D ( G ( z ( i ) ; θ g ) ) ) ] (3) \tag{3} -{1 \over m}\sum_{i}^{m}logD(x^{(i)}) - {1 \over m} \sum_{i=1}^{m} log(1-D(G(z^{(i)};\theta_g)))] m1imlogD(x(i))m1i=1mlog(1D(G(z(i);θg)))](3)

生成器的优化目标

生成器的优化目标是最小化(4)式:
E z ∼ p z [ l o g ( 1 − D ( G ( z ; θ g ) ) ) ] (4) \tag{4}E_{z \sim p_{z}}[log(1-D(G(z;\theta_g)))] Ezpz[log(1D(G(z;θg)))](4)
也就是最大化(5)式:
E z ∼ p z [ l o g ( D ( G ( z ; θ g ) ) ) ] (5) \tag{5}E_{z \sim p_{z}}[log(D(G(z;\theta_g)))] Ezpz[log(D(G(z;θg)))](5)

通过对鉴别器的分析,同理对于生成器就是最小化(6)式
− 1 m ∑ i = 1 m l o g ( D ( G ( z ( i ) ; θ g ) ) ) ] (6) \tag{6}-{1 \over m} \sum_{i=1}^{m} log(D(G(z^{(i)};\theta_g)))] m1i=1mlog(D(G(z(i);θg)))](6)

损失函数

其实(3)式和(6)式就分别是鉴别器和生成器的损失函数了,是不是长的和交叉熵有点像。

生产器和鉴别器最后的损失函数都是通过鉴别器的输出完成计算的。鉴别器的工作就是判别样本是真实的样本还是假的样本(也就是生成器生成的),所以鉴别器相当于一个二分类器。对于二分类问题,损失函数是:
∑ i = 1 m − t l o g p ( x ( i ) ) − ( 1 − t ) l o g ( 1 − p ( x ( i ) ) ) (7) \tag{7} \sum_{i=1}^{m}-tlogp(x^{(i)})-(1-t)log(1-p(x^{(i)})) i=1mtlogp(x(i))(1t)log(1p(x(i)))(7)

t t t表示标签,只有两种取值 0 或 者 1 0或者1 01 p ( x ( i ) ) p(x^{(i)}) p(x(i))是分类器对第 i i i个样本的预测结果,大于 0.5 0.5 0.5说明将其预测为标签为 1 1 1的一类,小于 0.5 0.5 0.5说明将其预测为标签为 0 0 0的一类。

在pytorch中,(7)式已经实现好了,可以用nn.BCEWithLogitsLoss(),也可以用nn.BCELoss()+nn.Sigmod()组合的方式。

鉴别器的损失函数

对于真实数据,其标签为1,将 t = 1 t=1 t=1带入(7)式得到:
∑ i = 1 m − l o g p ( x ( i ) ) \sum_{i=1}^{m}-logp(x^{(i)}) i=1mlogp(x(i))
这正是(3)式的第一项。

对于假数据,其标签为0,将 t = 0 t=0 t=0带入(7)式得到:
∑ i = 1 m − l o g ( 1 − p ( x ( i ) ) ) \sum_{i=1}^{m}-log(1-p(x^{(i)})) i=1mlog(1p(x(i)))
这正是(3)式的第二项。

假设 m m m是batch的大小,那么(3)式就是先用一个batch的真实数据(标签为1)训练鉴别器,再用一个batch的假数据(标签为0)训练鉴别器。

所以鉴别器的损失函数就是两个nn.BCEWithLogitsLoss()的和。

生成器的损失函数

对于生成器,其目的是迷惑鉴别器,所以它生成的样本的标签也是1。将 t = 1 t=1 t=1带入(7)式得:
∑ i = 1 m − l o g p ( x ( i ) ) \sum_{i=1}^{m}-logp(x^{(i)}) i=1mlogp(x(i))

这正是(6)式。所以生成器的损失函数就是一个nn.BCEWithLogitsLoss()

最后一个问题就是生成器的噪声从哪来?常用的就是使用高斯噪声,当然也可以使用其他的,读者可以自己多试试。

所以最后GAN的出现了三个交叉熵。

下面是pytorch在minist数据集上的简单实现:

import torch
from torch import nn
import torchvision.transforms as tfs
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import os


os.makedirs("test_images", exist_ok=True)
os.makedirs("save", exist_ok=True)


def preprocess_img(x):
    x = tfs.ToTensor()(x)  # x (0., 1.)
    return (x - 0.5) / 0.5  # x (-1., 1.)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
                    nn.Linear(784, 256),
                    nn.LeakyReLU(),
                    nn.Linear(256, 256),
                    nn.LeakyReLU(),
                    nn.Linear(256, 1),
                    )

    def forward(self, data):
        logits = self.net(data)
        return logits


class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 784),
            nn.Tanh(),
            )

    def forward(self, noise, only_compute_logits=True):
        logits  = self.net(noise)
        return logits


def compute_loss(logits, labels):
    bce_loss = nn.BCEWithLogitsLoss()     # sigmod and bceloss
    loss = bce_loss(logits, labels)
    return loss


# 使用 adam 来进行训练,beta1 是 0.5, beta2 是 0.999
def get_optimizer(net, LearningRate):
    optimizer = torch.optim.Adam(net.parameters(), lr=LearningRate, betas=(0.5, 0.999))
    return optimizer


def train_gan(D_net, G_net, D_optimizer, G_optimizer,
                noise_size, num_epochs):

    D_net = D_net.to(device)
    G_net = G_net.to(device)
    for epoch in range(num_epochs):
        for iteration, (x, _) in enumerate(train_data):
            batch = x.size(0)
            real_labels = torch.ones(batch, 1).to(device)
            fake_labels = torch.zeros(batch, 1).to(device)

            ####  train discriminator network  ####
            real_data = x.view(batch, -1).to(device)      # real data
            real_logits = D_net(real_data)
            real_data_loss = compute_loss(real_logits, real_labels)  # loss of real data on discriminator

            noise = torch.randn(batch, noise_size).to(device)    # generate nois 
            fake_images = G_net(noise).detach()  # generate fake data
            fake_logits = D_net(fake_images)
            fake_data_loss = compute_loss(fake_logits, fake_labels)  # loss of fake data on discriminator

            discriminator_loss = real_data_loss + fake_data_loss
            D_optimizer.zero_grad()
            discriminator_loss.backward()
            D_optimizer.step()

            ####  train generator network  ####
            noise = torch.randn(batch, noise_size).to(device)
            fake_images = G_net(noise)
            fake_logits = D_net(fake_images)
            # for generator, it want to cheat the discriminator, so its lables is real
            generator_loss = compute_loss(fake_logits, real_labels)  # loss of generator network
            G_optimizer.zero_grad()
            generator_loss.backward()
            G_optimizer.step()

            if iteration % 20 == 0:
                print('Epoch: {:2d} | Iter: {:<4d} | D: {:.4f} | G:{:.4f}'.format(epoch,
                                                                                  iteration,
                                                                                  discriminator_loss.cpu().detach().numpy(),
                                                                                  generator_loss.cpu().detach().numpy()))
            if iteration % 100 == 0:
                fake_images = fake_images.view(fake_images.size(0), -1, 28, 28)   # the second dim is channels of picture
                save_image(fake_images.data[:25], "test_images/%d_%d.png" % (epoch, iteration), nrow=5, normalize=True)

        torch.save(G_net, 'save/generator.pt')

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

if __name__ == '__main__':
    EPOCH = 50
    BATCH_SIZE = 128
    NOISE_DIM = 100
    train_set = MNIST(root='/data/mnist/',
                      train=True,
                      download=True,
                      transform=preprocess_img)
    train_data = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

    D = Discriminator()
    G = Generator(NOISE_DIM)

    D_optim = get_optimizer(D, 5e-4)
    G_optim = get_optimizer(G, 5e-4)

    train_gan(D, G, D_optim, G_optim,  NOISE_DIM, EPOCH)


如果你想测试训练好的生成器生成图片的效果,可以使用下面的代码:

import torch
from torchvision.utils import save_image
from main import Generator

batch = 128
noise_size = 100
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


G_net = torch.load('save/generator.pt')

noise = torch.randn(batch, noise_size).to(device)

fake_images = G_net(noise)

fake_images = fake_images.view(fake_images.size(0), -1, 28, 28)   # the second dim is channels of picture
save_image(fake_images.data[:50], "test_images/gen.png", nrow=10, normalize=True)

生成的效果:
在这里插入图片描述

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值