AI算法Python实现:BGAN(Boundary Seeking GAN)

目录

一、原理

1.1 GAN简单介绍

1.2 Boundary Seeking原理

1.2 BGAN原理

二、算法实现

三、参考资料


一、原理

1.1 GAN简单介绍

GAN(Generative Adversarial Network)主要是作为一种生成模型被广泛使用,它其实包含了两个模型,一个是生成模型(Generative Model),一个是判别模型(Discriminative Model),即生成器和判别器。GAN利用两者相互竞争来学习目标(数据)的分布,生成器会尝试欺骗判别器,让它认为生成的样本是真实的;判别器会尝试区分真实的样本和生成的样本。具体流程如下所示:

GAN的目标函数如下:

 \underset{G}{min}\underset{D}{max}V\left ( D, G \right )=E_{x\sim p_{data}\left ( x \right )}\left [ log D\left ( x \right ) \right ]+E_{z\sim p_{z}\left ( z \right )}\left [ log\left ( 1-D\left ( G\left ( z \right ) \right ) \right ) \right ]

训练过程中固定一方,更新另一个网络的参数,交替迭代。但是原始的GAN训练有着两个显著的缺陷:难以训练离散数据以及训练困难。而BGAN可以很好的解决这两点。

1.2 Boundary Seeking原理

Boundary Seeking是一种训练GAN的方式,它让生成器不直接依赖于判别器的输出,而是去寻找一个目标分布的边界,这个目标分布在理想情况下会和数据分布一致。这样做有两个好处:一是可以处理离散数据,比如文本或图像;二是可以避免GAN训练过程中出现的不稳定性或模式崩溃。

我们可以把目标分布的边界想象成一个圆形的围栏,里面有很多真实数据,比如二进制序列。生成器要尽量产生一些靠近围栏的样本,也就是说和真实数据很相似的样本。这样判别器就很难发现生成器产生的样本和真实数据之间的区别。如果生成器产生一些远离围栏的样本,比如非二进制序列,那么判别器就很容易识别出来,并给出一个很低的得分。这个得分就是生成器要优化的目标函数。

生成的数据如果在围栏的中心,也是和真实的数据很相似,但是这样的话,生成器就没有办法探索更多的可能性。因为在围栏的中心,生成器产生的样本和真实数据之间的距离都很小,判别器给出的得分都很高,生成器就没有梯度来更新参数。而如果生成器产生一些靠近边界的样本,那么判别器给出的得分就会有一定的变化,生成器就可以根据这个变化来调整参数。这样生成器就可以学习到更多的数据特征,并且避免了模式崩溃(mode collapse)。

1.2 BGAN原理

BGAN采用Boundary Seeking的方法对GAN进行训练,引入策略梯度(Policy Gradient)来解决离散值导致价值函数不是处处可微的问题。引入策略梯度后GAN不再直接根据是否骗过判别网络调整生成网络,而是间接基于判别网络的评价计算目标,可以提高训练的稳定度。

原始GAN论文中表示,最优的判别器为:

D_{G}^{*}(x)=\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}

因此,如果我们知道每个生成器对应的最优判别器就可以重新整理上面的方程,最终变成下面这样:

p_{data}(x)=p_{g}(x)\frac{D_{G}^{*}(x)}{1-D_{G}^{*}(x)}

从这个方程我们可以看出,即使我们没有得到最优的生成器G,仍然可以通过调整p_{g}(x)、生成器的分布、生成器与判别器的比例,得到真实数据的分布。虽然我们很难得到最优的判别器,但是,我们可以通过不断地训练D(x)来迫近它,我们的训练效果也将越来越好。

如果我们训练出来的生成器足够完美,那么p_{g}(x)将无限接近于p_{data}(x),判别器将无法判断生成样本和真实样本之间的区别,即D(x)=0.5。因此最优的生成器就是能使判别器处处都为0.5的那个。这个D(x)=0.5便是我们要找的决策边界,也就是上面提到的基于判别网络的评价计算目标。这样的话,我们可以调整生成器的目标函数,使得判别器的输出都为0.5。新的生成器目标函数如下:

\underset{G}{min}E_{x\sim p_{G}(x)}[0.5(log(D(x)-log(1-D(x))))^{2}]

 其目标函数的目的是减少D(x)1-D(x)之间的距离,即使D(x)=0.5

二、算法实现

  • models
    • BGAN.py
    • __init__.py
  • data
    • mnist
  • train.py

BGAN.py

import numpy as np
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self, latent_dim, image):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.image = image

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(self.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.image))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], self.image[0], self.image[1], self.image[2])
        return img


class Discriminator(nn.Module):
    def __init__(self, image):
        super(Discriminator, self).__init__()
        self.image = image

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(self.image)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity

train.py

import os
import argparse
import torch
import numpy as np
import torchvision.transforms as transforms

from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from models.BGAN import Generator, Discriminator


os.makedirs("images", exist_ok=True)


def parser_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=1, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
    args = parser.parse_args()

    return args


def boundary_seeking_loss(y_pred):
    """
    Boundary seeking loss.
    """
    return 0.5 * torch.mean((torch.log(y_pred) - torch.log(1 - y_pred)) ** 2)


def train(gen, disc, disc_loss, device, dataloader, optim_G, optim_D, n_epochs, latent_dim, sample_interval):

    gen.to(device)
    disc.to(device)
    disc_loss.to(device)

    tensor = torch.cuda.FloatTensor

    for epoch in range(n_epochs):
        for i, (img, _) in enumerate(dataloader):

            # Adversarial ground truths
            valid = Variable(tensor(img.shape[0], 1).fill_(1.0), requires_grad=False)
            fake = Variable(tensor(img.shape[0], 1).fill_(0.0), requires_grad=False)

            # Configure input
            real_img = Variable(img.type(tensor))

            # -----------------
            #  Train Generator
            # -----------------

            optim_G.zero_grad()

            # Sample noise as generator input
            z = Variable(tensor(np.random.normal(0, 1, (img.shape[0], latent_dim))))

            # Generate a batch of images
            gen_img = gen(z)

            # Loss measures generator's ability to fool the discriminator
            g_loss = boundary_seeking_loss(disc(gen_img))

            g_loss.backward()
            optim_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optim_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = disc_loss(disc(real_img), valid)
            fake_loss = disc_loss(disc(gen_img.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optim_D.step()

            if i % 100 == 0:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                    % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
                )

            batches_done = epoch * len(dataloader) + i
            if batches_done % sample_interval == 0:
                save_image(gen_img.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)


def main():
    args = parser_args()
    img_shape = (args.channels, args.img_size, args.img_size)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Initialize generator and discriminator
    gen = Generator(args.latent_dim, img_shape)
    disc = Discriminator(img_shape)
    disc_loss = torch.nn.BCELoss()

    # Configure data loader
    os.makedirs("./data/mnist", exist_ok=True)
    dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(args.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=args.batch_size,
        shuffle=True,
    )

    # Optimizers
    optim_g = torch.optim.Adam(gen.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    optim_d = torch.optim.Adam(disc.parameters(), lr=args.lr, betas=(args.b1, args.b2))

    train(gen, disc, disc_loss, device, dataloader, optim_g, optim_d, args.n_epochs, args.latent_dim,
          args.sample_interval)

    return


if __name__ == '__main__':
    main()

训练结果:

三、参考资料

1. 原论文地址

2. BGAN

3. 生成对抗网络(GAN)

4. PyTorch Implementation of Boundary Seeking GAN

5. BGAN:支持离散值、提升训练稳定性的新GAN训练方法

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值