目录
一、原理
1.1 GAN简单介绍
GAN(Generative Adversarial Network)主要是作为一种生成模型被广泛使用,它其实包含了两个模型,一个是生成模型(Generative Model),一个是判别模型(Discriminative Model),即生成器和判别器。GAN利用两者相互竞争来学习目标(数据)的分布,生成器会尝试欺骗判别器,让它认为生成的样本是真实的;判别器会尝试区分真实的样本和生成的样本。具体流程如下所示:
GAN的目标函数如下:
训练过程中固定一方,更新另一个网络的参数,交替迭代。但是原始的GAN训练有着两个显著的缺陷:难以训练离散数据以及训练困难。而BGAN可以很好的解决这两点。
1.2 Boundary Seeking原理
Boundary Seeking是一种训练GAN的方式,它让生成器不直接依赖于判别器的输出,而是去寻找一个目标分布的边界,这个目标分布在理想情况下会和数据分布一致。这样做有两个好处:一是可以处理离散数据,比如文本或图像;二是可以避免GAN训练过程中出现的不稳定性或模式崩溃。
我们可以把目标分布的边界想象成一个圆形的围栏,里面有很多真实数据,比如二进制序列。生成器要尽量产生一些靠近围栏的样本,也就是说和真实数据很相似的样本。这样判别器就很难发现生成器产生的样本和真实数据之间的区别。如果生成器产生一些远离围栏的样本,比如非二进制序列,那么判别器就很容易识别出来,并给出一个很低的得分。这个得分就是生成器要优化的目标函数。
生成的数据如果在围栏的中心,也是和真实的数据很相似,但是这样的话,生成器就没有办法探索更多的可能性。因为在围栏的中心,生成器产生的样本和真实数据之间的距离都很小,判别器给出的得分都很高,生成器就没有梯度来更新参数。而如果生成器产生一些靠近边界的样本,那么判别器给出的得分就会有一定的变化,生成器就可以根据这个变化来调整参数。这样生成器就可以学习到更多的数据特征,并且避免了模式崩溃(mode collapse)。
1.2 BGAN原理
BGAN采用Boundary Seeking的方法对GAN进行训练,引入策略梯度(Policy Gradient)来解决离散值导致价值函数不是处处可微的问题。引入策略梯度后GAN不再直接根据是否骗过判别网络调整生成网络,而是间接基于判别网络的评价计算目标,可以提高训练的稳定度。
原始GAN论文中表示,最优的判别器为:
因此,如果我们知道每个生成器对应的最优判别器就可以重新整理上面的方程,最终变成下面这样:
从这个方程我们可以看出,即使我们没有得到最优的生成器G,仍然可以通过调整、生成器的分布、生成器与判别器的比例,得到真实数据的分布。虽然我们很难得到最优的判别器,但是,我们可以通过不断地训练来迫近它,我们的训练效果也将越来越好。
如果我们训练出来的生成器足够完美,那么将无限接近于,判别器将无法判断生成样本和真实样本之间的区别,即。因此最优的生成器就是能使判别器处处都为0.5的那个。这个便是我们要找的决策边界,也就是上面提到的基于判别网络的评价计算目标。这样的话,我们可以调整生成器的目标函数,使得判别器的输出都为0.5。新的生成器目标函数如下:
其目标函数的目的是减少于之间的距离,即使。
二、算法实现
- models
- BGAN.py
- __init__.py
- data
- mnist
- train.py
BGAN.py
import numpy as np
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, latent_dim, image):
super(Generator, self).__init__()
self.latent_dim = latent_dim
self.image = image
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*block(self.latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(self.image))),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.shape[0], self.image[0], self.image[1], self.image[2])
return img
class Discriminator(nn.Module):
def __init__(self, image):
super(Discriminator, self).__init__()
self.image = image
self.model = nn.Sequential(
nn.Linear(int(np.prod(self.image)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid(),
)
def forward(self, img):
img_flat = img.view(img.shape[0], -1)
validity = self.model(img_flat)
return validity
train.py
import os
import argparse
import torch
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from models.BGAN import Generator, Discriminator
os.makedirs("images", exist_ok=True)
def parser_args():
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
args = parser.parse_args()
return args
def boundary_seeking_loss(y_pred):
"""
Boundary seeking loss.
"""
return 0.5 * torch.mean((torch.log(y_pred) - torch.log(1 - y_pred)) ** 2)
def train(gen, disc, disc_loss, device, dataloader, optim_G, optim_D, n_epochs, latent_dim, sample_interval):
gen.to(device)
disc.to(device)
disc_loss.to(device)
tensor = torch.cuda.FloatTensor
for epoch in range(n_epochs):
for i, (img, _) in enumerate(dataloader):
# Adversarial ground truths
valid = Variable(tensor(img.shape[0], 1).fill_(1.0), requires_grad=False)
fake = Variable(tensor(img.shape[0], 1).fill_(0.0), requires_grad=False)
# Configure input
real_img = Variable(img.type(tensor))
# -----------------
# Train Generator
# -----------------
optim_G.zero_grad()
# Sample noise as generator input
z = Variable(tensor(np.random.normal(0, 1, (img.shape[0], latent_dim))))
# Generate a batch of images
gen_img = gen(z)
# Loss measures generator's ability to fool the discriminator
g_loss = boundary_seeking_loss(disc(gen_img))
g_loss.backward()
optim_G.step()
# ---------------------
# Train Discriminator
# ---------------------
optim_D.zero_grad()
# Measure discriminator's ability to classify real from generated samples
real_loss = disc_loss(disc(real_img), valid)
fake_loss = disc_loss(disc(gen_img.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optim_D.step()
if i % 100 == 0:
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
)
batches_done = epoch * len(dataloader) + i
if batches_done % sample_interval == 0:
save_image(gen_img.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
def main():
args = parser_args()
img_shape = (args.channels, args.img_size, args.img_size)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Initialize generator and discriminator
gen = Generator(args.latent_dim, img_shape)
disc = Discriminator(img_shape)
disc_loss = torch.nn.BCELoss()
# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
datasets.MNIST(
"./data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(args.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
),
batch_size=args.batch_size,
shuffle=True,
)
# Optimizers
optim_g = torch.optim.Adam(gen.parameters(), lr=args.lr, betas=(args.b1, args.b2))
optim_d = torch.optim.Adam(disc.parameters(), lr=args.lr, betas=(args.b1, args.b2))
train(gen, disc, disc_loss, device, dataloader, optim_g, optim_d, args.n_epochs, args.latent_dim,
args.sample_interval)
return
if __name__ == '__main__':
main()
训练结果: