原文链接:https://arxiv.org/pdf/1609.03126.pdf
简介
背景:此文又回到了GAN绕不过的难题上,即如何稳定优化,本篇文章与之前“WGAN”和“W系列GAN”不同,它并不是通过设计目标函数来达到此目的,而是另辟蹊径通过改变判别器的结构来实现的。
核心思想:将判别器改为一个能量函数,使得数据在流行分布附近时能量低,在其他地方时能量高。
由上图可知,判别器的结构发生了改变,不似以往单一的神经网络结构,而是由一对编码器与解码器组成,且能量的输出是基于整个编解码结构来构造的。这样设计最直接的优点,就是LOSS函数我们无需自己绞尽脑汁地自己去特定地构造,而可以直接使用已经成熟的LOSS,仅需套上此框架让其输出能量来作为优化目标就行。
基础结构
基础概念
纳什均衡
又称为非合作博弈均衡,在一个博弈过程中,无论对方的策略选择如何,当事人一方都会选择某个确定的策略,则该策略被称作支配性策略。如果两个博弈的当事人的策略组合分别构成各自的支配性策略,那么这个组合就被定义为纳什平衡。一个策略组合被称为纳什平衡,当每个博弈者的平衡策略都是为了达到自己期望收益的最大值,与此同时,其他所有博弈者也遵循这样的策略。
那么在GAN中,当判别器与生成器的合作博弈过程接近纳什均衡时,我们就会陷入优化的瓶颈。
目标函数
其中,此处的m为边距需满足。那么上面的两个LOSS函数,实际上就是在判别器的优化过程中加上了一个限制项,此限制项在生成数据与x分布过于靠近时拉大判别器LOSS,加快判别器的优化过程;但是,当生成数据与x分布过于远离时,此时限制项限制住生成数据通过判别器时产生的LOSS,即让判别器的优化过程先停下来,等待生成器优化将生成数据的分布拉向x。
训练判别器时最小化V,训练生成器时最小化U,G与D组成一对纳什均衡,那么满足:
D代表最优判别器,G代表最优生成器,以此我们可以确定两个优化的上界。
因为构造的关系,,则达到最小值时,有
因为第二项中两个因子必定是一正一负的,所以积分后的值必定是在[-1,0]之间,所以最大值为m,即。于是,又因为
则,最终得,即,这种情况产生时
此项为零,即,达到我们的优化目标。
编解码结构
训练自动编码器的一个常见问题是,模型可能学到的不是一个恒等函数,这意味着它可能将整个空间赋值为0能量。为了避免这个问题,必须强制模型给数据流形之外的点提供更高的能量。这种规范器旨在限制自动编码器的重构能力,使得它只能将低能量归入较小部分的输入点。
EBGAN框架中的能量函数(判别器)也被看作是通过产生对比性的样本的发生器来规范化的,判别器应该给予对比性的样本赋予高的重构能量。从这个角度来看,EBGAN框架允许更多的灵活性,因为:(i)规范器(生成器)完全可以训练而不是人工指定;(2)对抗训练模式使产生有对比性的样本与学习能量函数两个目标之间可直接相互作用。
D 的自动编码器的选择乍一看似乎是任意的,但作者的设定使得它比二元分类网络更有吸引力:
(1)基于重建的输出不是使用单个目标信息来训练模型,而是为判别器提供多样化的目标。由于二元分类网络,只有两个目标是可能的,所以在一个小批次内,对应于不同样本的梯度最有可能远离正交,这导致了低效率的训练,并且当前的硬件通常不提供减少小批量的尺寸的选择。另一方面,重建损失可能会在批次内产生非常不同的梯度方向,允许更大的批量大小而不损失效率。
(2)传统上使用自动编码器来表示基于能量的模型。当用正则化训练时, 自动编码器可以在无监督或反例的情况下学习能量流形。这意味着,当EBGAN自动编码模型被训练以重构真实样本时,判别器也有助于发现数据流形。相反,如果没有来自生成器的负面例子,用二元分类损失训练的判别器变得毫无意义。
上面这一段摘自文章的翻译,可以总结为引入编解码结构,使得生成数据可以具有更多的多样性,同时由于使用了编解码,则原来简单的判断真假也需要改编为编码前后的损失,文中使用的MSE。
代码与实践结果
参考链接:https://github.com/WingsofFAN/PyTorch-GAN/blob/master/implementations/ebgan/ebgan.py
import argparse
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
os.makedirs("images", exist_ok=True)
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=62, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="number of image channels")
opt = parser.parse_args()
print(opt)
img_shape = (opt.channels, opt.img_size, opt.img_size)
cuda = True if torch.cuda.is_available() else False
def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.init_size = opt.img_size // 4
self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))
self.conv_blocks = nn.Sequential(
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 128, 3, stride=1, padding=1),
nn.BatchNorm2d(128, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2),
nn.Conv2d(128, 64, 3, stride=1, padding=1),
nn.BatchNorm2d(64, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
nn.Tanh(),
)
def forward(self, noise):
out = self.l1(noise)
out = out.view(out.shape[0], 128, self.init_size, self.init_size)
img = self.conv_blocks(out)
return img
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
# Upsampling
self.down = nn.Sequential(nn.Conv2d(opt.channels, 64, 3, 2, 1), nn.ReLU())
# Fully-connected layers
self.down_size = opt.img_size // 2
down_dim = 64 * (opt.img_size // 2) ** 2
self.embedding = nn.Linear(down_dim, 32)
self.fc = nn.Sequential(
nn.BatchNorm1d(32, 0.8),
nn.ReLU(inplace=True),
nn.Linear(32, down_dim),
nn.BatchNorm1d(down_dim),
nn.ReLU(inplace=True),
)
# Upsampling
self.up = nn.Sequential(nn.Upsample(scale_factor=2), nn.Conv2d(64, opt.channels, 3, 1, 1))
def forward(self, img):
out = self.down(img)
embedding = self.embedding(out.view(out.size(0), -1))
out = self.fc(embedding)
out = self.up(out.view(out.size(0), 64, self.down_size, self.down_size))
return out, embedding
# Reconstruction loss of AE
pixelwise_loss = nn.MSELoss()
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
if cuda:
generator.cuda()
discriminator.cuda()
pixelwise_loss.cuda()
# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
datasets.MNIST(
"../../data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
)
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
def pullaway_loss(embeddings):
norm = torch.sqrt(torch.sum(embeddings ** 2, -1, keepdim=True))
normalized_emb = embeddings / norm
similarity = torch.matmul(normalized_emb, normalized_emb.transpose(1, 0))
batch_size = embeddings.size(0)
loss_pt = (torch.sum(similarity) - batch_size) / (batch_size * (batch_size - 1))
return loss_pt
# ----------
# Training
# ----------
# BEGAN hyper parameters
lambda_pt = 0.1
margin = max(1, opt.batch_size / 64.0)
for epoch in range(opt.n_epochs):
for i, (imgs, _) in enumerate(dataloader):
# Configure input
real_imgs = Variable(imgs.type(Tensor))
# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad()
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
# Generate a batch of images
gen_imgs = generator(z)
recon_imgs, img_embeddings = discriminator(gen_imgs)
# Loss measures generator's ability to fool the discriminator
g_loss = pixelwise_loss(recon_imgs, gen_imgs.detach()) + lambda_pt * pullaway_loss(img_embeddings)
g_loss.backward()
optimizer_G.step()
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Measure discriminator's ability to classify real from generated samples
real_recon, _ = discriminator(real_imgs)
fake_recon, _ = discriminator(gen_imgs.detach())
d_loss_real = pixelwise_loss(real_recon, real_imgs)
d_loss_fake = pixelwise_loss(fake_recon, gen_imgs.detach())
d_loss = d_loss_real
if (margin - d_loss_fake.data).item() > 0:
d_loss += margin - d_loss_fake
d_loss.backward()
optimizer_D.step()
# --------------
# Log Progress
# --------------
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
)
batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
mnist测试结果
可见训练的的结果并不理想,因为它只是证明此方法优化的稳定了,但是或许收敛速度并不快。