WGAN学习笔记(从理论分析到Pytorch代码实践)

原始的GAN存在的问题

原始的GAN的目标函数实际上可以表现为
− E x ∼ P r [ l o g D ( x ) ] − E x ∼ P g [ l o g ( 1 − D ( x ) ) ] -E_{x\sim{P_r}}[logD(x)] - E_{x\sim{P_g}}[log(1-D(x))] ExPr[logD(x)]ExPg[log(1D(x))]
这里 P r P_r Pr是真实的样本分布, P g P_g Pg是生成器产生的样本分布。
WGAN \textbf{WGAN} WGAN对于这个目标函数从训练过程中进行了分析,因为我们在训练过程中,是固定G训练D,固定D训练G交替进行的,从上述公式进行分析,当G固定时对于一个样本x,那么代入上述公式可以得到
− P r ( x ) l o g D ( x ) − P g ( x ) l o g [ 1 − D ( x ) ] -P_r(x)logD(x) - P_g(x)log[1-D(x)] Pr(x)logD(x)Pg(x)log[1D(x)]
此时要求得最优的D,需要对D按照x进行求导,并令其导数为0
− P r ( x ) D ( x ) + P g ( x ) 1 − D ( x ) = 0 -\frac{P_r(x)}{D(x)} + \frac{P_g(x)}{1-D(x)} = 0 D(x)Pr(x)+1D(x)Pg(x)=0
可以得到最优化的判别器
D ∗ ( x ) = P r ( x ) P r ( x ) + P g ( x ) D^*(x) = \frac{P_r(x)}{P_r(x)+P_g(x)} D(x)=Pr(x)+Pg(x)Pr(x)
得到这个答案后,将最优判别器的表达式代入原始GAN的目标函数中,并进行简单的变换(变换主要是为了引入KL散度)
E x ∼ P r l o g ( P r ( x ) 0.5 ∗ [ P r ( x ) + P g ( x ) ] ) + E x ∼ P g l o g ( P g ( x ) 0.5 ∗ [ P r ( x ) + P g ( x ) ] ) − 2 l o g 2 E_{x\sim{P_r}}log(\frac{P_r(x)}{0.5*[P_r(x)+P_g(x)]}) + E_{x\sim{P_g}}log(\frac{P_g(x)}{0.5*[P_r(x)+P_g(x)]}) - 2log2 ExPrlog(0.5[Pr(x)+Pg(x)]Pr(x))+ExPglog(0.5[Pr(x)+Pg(x)]Pg(x))2log2
在上述表达式的情况下,我们已知KL散度和JS散度的表达式
K L ( P 1 ∣ ∣ P 2 ) = E x ∼ P 1 l o g ( P 1 P 2 ) KL(P_1||P_2) = E_{x\sim{P1}}log(\frac{P_1}{P_2}) KL(P1P2)=ExP1log(P2P1)
J S ( P 1 ∣ ∣ P 2 ) = 1 2 K L ( P 1 ∣ ∣ P 1 + P 2 2 ) + 1 2 K L ( P 2 ∣ ∣ P 1 + P 2 2 ) JS(P_1||P_2) = \frac{1}{2}KL(P_1||\frac{P_1+P_2}{2}) + \frac{1}{2}KL(P_2||\frac{P_1+P_2}{2}) JS(P1P2)=21KL(P12P1+P2)+21KL(P22P1+P2)
在这里借助于JS散度的表达式代入改造后的目标函数,可以得到
2 J S ( P r ∣ ∣ P g ) − 2 l o g 2 2JS(P_r||P_g) -2log2 2JS(PrPg)2log2
在这里就可以很清楚了,可以看到原始GAN,在最优判别器的情况下,训练generator实际上是在最优化真实分布 P r P_r Pr和生成分布 P g P_g Pg之间的JS散度,但是JS散度本身存在一个大问题,如果真实分布和生成分布之间没有重叠部分,我们来计算下JS散度
P 1 ( x ) = 0  且  P 2 ( x ) = 0 P_1(x) = 0 \text{ 且 } P_2(x)=0 P1(x)=0  P2(x)=0
P 1 ( x ) ≠ 0  且  P 2 ( x ) ≠ 0 P_1(x) \neq 0 \text{ 且 } P_2(x) \neq 0 P1(x)̸=0  P2(x)̸=0
P 1 ( x ) = 0  且  P 2 ≠ 0 P_1(x) = 0 \text{ 且 } P_2 \neq 0 P1(x)=0  P2̸=0
P 1 ( x ) ≠ 0  且  P 2 ( x ) = 0 P_1(x) \neq 0 \text{ 且 } P_2(x) =0 P1(x)̸=0  P2(x)=0
第一种情况对JS散度无贡献,第二种情况由于重叠部分不存在,因为也为0,第三种情况右边向 l o g ( P 2 0.5 ∗ ( P 2 + 0 ) ) = 0 log(\frac{P_2}{0.5*(P_2+0)})=0 log(0.5(P2+0)P2)=0,所以极限情况是log2,这样的情况就说明,当两个分布之间没有重合的时候JS散度始终不变,这意味着无法通过梯度变换反向更新网络参数,因为JS散度变化为0

术语的解释

P r P_r Pr P g P_g Pg的支撑集是高维空间中的地位流形时, P r P_r Pr P g P_g Pg重叠部分的测度为0的概率为1,这里有一些数学概念,下面给出解释

  1. 支撑集:就是函数非零部分子集,例如ReLU函数的支撑集就是 ( 0 , + ∞ ) (0,+\infty) (0,+)
  2. 流形:高维空间中的曲面与曲线,是三维空间的曲面曲线概念的推广
  3. 测度:高维空间中长度、面积、体积概念的推广,可理解为"超体积”
    P r P_r Pr P g P_g Pg的支撑集是高维空间中的低维流形时”,基本上是成立的。原因是GAN中的生成器一般是从某个低维(比如100维)的随机分布中采样出一个编码向量,再经过一个神经网络生成出一个高维样本(比如64x64的图片就有4096维)。当生成器的参数固定时,生成样本的概率分布虽然是定义在4096维的空间上,但它本身所有可能产生的变化已经被那个100维的随机分布限定了,其本质维度就是100,再考虑到神经网络带来的映射降维,最终可能比100还小,所以生成样本分布的支撑集就在4096维空间中构成一个最多100维的低维流形,“撑不满”整个高维空间。
    这里原始GAN难以训练的原因就清楚了
对于Generatro loss改进后仍然存在的问题

但是如果将生成器loss在进行更改,改为 E x ∼ P g [ − l o g D ( x ) ] E_{x\sim{P_g}}[-logD(x)] ExPg[logD(x)]形式,这种情况也是存在的,因此需要分析,将上文推导的最佳Discriminator表达式 D ∗ D^* D代入可得到,优化目标函数表达式为
2 J S ( P r ∣ ∣ P g ) − 2 l o g 2 2JS(P_r||P_g) -2log2 2JS(PrPg)2log2
最终可以将优化目标函数化为
K L ( P g ∣ ∣ P r ) − 2 J S ( P r ∣ ∣ P g ) + 2 l o g 2 + E x ∼ P r [ l o g D ∗ ( x ) ] KL(P_g||P_r) - 2JS(P_r||P_g) +2log2 + E_{x\sim{P_r}}[logD^*(x)] KL(PgPr)2JS(PrPg)+2log2+ExPr[logD(x)]
很显然,要确定最终的G只取决于前两项
但是这个优化目标有两个大问题,同时最小化生成分布与真实分布KL散度,但同时又要最大化两者JS分布,这个实际上是不可能的,因为JS和KL趋势一致,这就造成了训练时候梯度不稳定。另外KL散度也有局限,根据KL散度的计算公式, K L ( P g ∣ ∣ P r ) KL(P_g||P_r) KL(PgPr) K L ( P r ∣ ∣ P g ) KL(P_r||P_g) KL(PrPg),不对称,下面分析一下这种错误
P g ( x ) → 0  且  P r ( x ) → 1 , P g ( x ) l o g ( P g ( x ) P r ( x ) ) → 0  对KL贡献趋近于0 P_g(x) \rightarrow 0 \text{ 且 }P_r(x) \rightarrow 1, P_g(x)log(\frac{P_g(x)}{P_r(x)}) \rightarrow 0 \text{ 对KL贡献趋近于0} Pg(x)0  Pr(x)1,Pg(x)log(Pr(x)Pg(x))0 KL贡献趋近于0
P g ( x ) → 1  且  P r ( x ) → 0 , P g ( x ) l o g ( P g ( x ) P r ( x ) ) → + ∞ , 对KL贡献趋近无穷大 P_g(x) \rightarrow 1 \text{ 且 }P_r(x) \rightarrow 0 , P_g(x)log(\frac{P_g(x)}{P_r(x)}) \rightarrow +\infty, \text{对KL贡献趋近无穷大} Pg(x)1  Pr(x)0,Pg(x)log(Pr(x)Pg(x))+,KL贡献趋近无穷大
第一种情况对应着生成器没有生成真实样本,惩罚小、第二种情况对应的是生成器生成了不真实样本,惩罚巨大,这样实际上就上生成器生成了一些重复但是非真实的样本了,这就是大家所说的collapse mode

针对原始GAN的不足,最开始的修正方案

综上所述:原始GAN存在两点不足,第一点使用KL散度与JS散度不合理,第二点是generator的分布与真实分布很难有足够程度的重合 ,原先的做法是对生成样本和真实样本家噪声,这个的作用相当于将两个低维流形进行扩散,改变原有的分布,直观上来说可以理解为二维情况下加噪声会使得原本的标准高斯分布变胖,这样的结果就是强制使这两者有足够程度的重合部分,这样JS散度就会起作用,然后随着训练的进行,逐渐降低加入噪声的程度,使得这个低维流形更为还原为原来的情况,同时确保JS有用,但是这里有一个很大的问题:因为具体计算出的JS散度同加入的噪声有关,随着噪声的降低,因此也会带来loss本身的降低,这种情况下 P r P_r Pr P g P_g Pg这两个分布的接近所带来的loss就很难从整体的loss降低过程中区分出来,所以这个loss很难有什么作用。

引入Wasserstein距离的优越性质

Wasserstein距离称之为EM距离,定义如下:
W ( P r , P g ) = i n f r ∼ ∏ P r , P g E ( x , y ) [ ∣ ∣ x − y ∣ ∣ ] W(P_r,P_g) = \mathop{inf}\limits_{r\sim{\prod{P_r,P_g}}}E_(x,y)[||x-y||] W(Pr,Pg)=rPr,PginfE(x,y)[xy]
上式中 ∏ ( P r , P g ) \prod(P_r,P_g) (Pr,Pg) P r P_r Pr P g P_g Pg组合起来的所有可能的联合分布集合,对于这样的一个联合分布 r r r中进行抽样可以得到一个真实样本 x x x和生成样本 y y y,并计算出这样距离||x-y||,所以可以求得这样的期望值,在这些所有联合分布中的期望值取得一个下界,就称之为Wasserstein距离,直观来说就相对于在所有(x,y)的联合分布r中将 P g P_g Pg推到 P r P_r Pr,这种情况下求最小的
移动距离,所以Wasserstein距离也叫作推土机距离,这种散度相对于KL散度和JS散度的优点在于,即使是两个分布没有重叠部分,也能够反映两者的远近,并指导两者逐步靠近

从Wasserstein距离到WGAN

在从Wasserstein距离变化到可以使用的loss时,需要进行一些变换,WGAN的作者提出,借助于已有定理可以将其变换为
W ( P r , P g ) = 1 K s u p ∣ ∣ f ∣ ∣ L ≤ K E x ∼ P r [ f ( x ) ] − E x ∼ P g [ f ( x ) ] W(P_r,P_g) = \frac{1}{K}\mathop{sup}\limits_{||f||_L\leq{K}}E_{x\sim{P_r}}[f(x)] - E_{x\sim{P_g}}[f(x)] W(Pr,Pg)=K1fLKsupExPr[f(x)]ExPg[f(x)]
在证明过程中使用到了一个Lipschitz连续的概念,这一点在后续WGAN的重大改进中是一个重点改进方向,libschitz连续是指,一个函数在它的定义域内存在如下的约束
∣ f ( x 1 ) − f ( x 2 ) ∣ ≤ K ∣ x 1 − x 2 ∣ |f(x_1) - f(x_2)| \leq K|x_1-x_2| f(x1)f(x2)Kx1x2
并且 K ≥ 0 K \geq 0 K0,这个常数K称为Lipschitz常数,那么对于上面wasserstein距离的变换,认为里面的函数 f ( x ) f(x) f(x)可以使用神经网络来拟合,这里需要说明 s u p sup sup表示的是上确界,实际上经过推土机距离改造后的优化目标函数应该改写为
K ⋅ W ( P r , P g ) ≈ m a x w : ∣ f w ∣ L ≤ K E x ∼ P r ∣ f w ( x ) ∣ − E x ∼ P g [ f w ( x ) ] K\cdot{W(P_r,P_g)} \approx{\mathop{max}\limits_{w:|f_w|_{L\leq{K}}}}E_{x\sim{P_r}}|f_w(x)|-E_{x\sim{P_g}}[f_w(x)] KW(Pr,Pg)w:fwLKmaxExPrfw(x)ExPg[fw(x)]
然后为了满足libschitz约束,存在 ∣ ∣ f w ( x ) ∣ ∣ &lt; L ||f_w(x)||&lt;L fw(x)<L限制,这实际上说明了 ∂ f w ∂ x \frac{\partial{f_w}}{\partial{x}} xfw应该是一个有界的,这个界跟K有关,但却是一个有限的数
回到实际的网络设计上,在原始GAN的设计中discirminator面临的是一个二分类问题,因此最后一层经常使用sigmoid,但是WGAN中面临的是一个近似拟合Wasserstein距离的任务,属于回归任务,因此discriminator的最后一层不需要加上sigmoid;对于generator来说上述目标函数中的 P r P_r Pr相关项与生成器无关,则可以分别得到对于discriminator和generator的GAN loss
− E x ∼ P g [ f w ( x ) ]  (generator GAN loss) -E_{x\sim{P_g}}[f_w(x)] \text{ (generator GAN loss)} ExPg[fw(x)] (generator GAN loss)
E x ∼ P g [ f w ( x ) ] − E x ∼ P r [ f w ( x ) ]  discriminator GAN loss E_{x\sim{P_g}}[f_w(x)] - E_{x\sim{P_r}}[f_w(x)] \text{ discriminator GAN loss} ExPg[fw(x)]ExPr[fw(x)] discriminator GAN loss

从代码出发观察WGAN在训练中的处理

  1. 去除sigmoid,最后一层普遍使用的是conv2D,例如如下所示:
netd = nn.Sequential(
            nn.Conv2d(opt.nc,opt.ndf,4,2,1,bias=False),
            nn.LeakyReLU(0.2,inplace=True),
            
            nn.Conv2d(opt.ndf,opt.ndf*2,4,2,1,bias=False),
            nn.BatchNorm2d(opt.ndf*2),
            nn.LeakyReLU(0.2,inplace=True),
            
            nn.Conv2d(opt.ndf*2,opt.ndf*4,4,2,1,bias=False),
            nn.BatchNorm2d(opt.ndf*4),
            nn.LeakyReLU(0.2,inplace=True),
            
            nn.Conv2d(opt.ndf*4,opt.ndf*8,4,2,1,bias=False),
            nn.BatchNorm2d(opt.ndf*8),
            nn.LeakyReLU(0.2,inplace=True),
            
            nn.Conv2d(opt.ndf*8,1,4,1,0,bias=False),
            # Modification 1: remove sigmoid
            # nn.Sigmoid()
        )
  1. 不使用具有动量的优化方法,比如使用Adam,转而使用诸如RMSProp,SGD等方法,尤其是作者推荐使用RMSProp,因为该方法可以处理梯度不稳定的情况,如下所示:
optimizerD = RMSprop(netd.parameters(),lr=opt.lr ) 
optimizerG = RMSprop(netg.parameters(),lr=opt.lr )  
  1. 需要对discriminator的权重做修整限制以确保lipschitz连续约束,代码示例如下
 for p in netD.parameters():
    p.data.clamp_(clamp_lower, clamp_upper)

这里的clamp_lower和clamp_upper是文章中的约束范围,这里的取值是经验参数,有人推荐使用-0.01和0.01
4. 将BCEloss 改为非log的loss,按照文章的记载,通常会使用直接同1和-1做比较,代码示例如下

one=t.FloatTensor([1])
mone=-1*one
...
output=netd(input)
output.backward(one)
...
output2=netd(fake_pic)
output2.backward(mone)

参考内容

[1]. Ian J. Goodfellow, Generative Adversarial Nets ,2014
[2]. Martin Arjovsky Wasserstein GAN, 2017
[3]. https://zhuanlan.zhihu.com/p/25071913
[4]. https://github.com/chenyuntc/pytorch-GAN/blob/master/WGAN.ipynb
[5]. https://github.com/tjwei/GANotebooks
[6]. https://github.com/igul222/improved_wgan_training

WGAN(Wasserstein GAN)是GAN(Generative Adversarial Network)的一种改进模型,它通过使用Wasserstein距离替代JS散度来解决GAN中的训练不稳定问题,从而提高了生成器和判别器的训练效果。在这里,我将介绍如何使用PyTorch实现WGAN来生成动漫头像。 首先,我们需要导入必要的库: ```python import torch import torch.nn as nn import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision.datasets import ImageFolder import matplotlib.pyplot as plt import numpy as np ``` 接下来,我们定义一些超参数: ```python batch_size = 64 # 批次大小 n_epochs = 200 # 训练轮数 z_dim = 100 # 噪声维度 lr = 0.00005 # 学习率 clip_value = 0.01 # 截断值 n_critic = 5 # 判别器训练次数 ``` 然后,我们定义生成器和判别器的网络结构: ```python class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.fc = nn.Sequential( nn.Linear(z_dim, 4*4*512), nn.BatchNorm1d(4*4*512) ) self.conv = nn.Sequential( nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(True), nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(True), nn.ConvTranspose2d(64, 3, 4, 2, 1), nn.Tanh() ) def forward(self, z): x = self.fc(z) x = x.view(-1, 512, 4, 4) x = self.conv(x) return x class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.conv = nn.Sequential( nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2), nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2), nn.Conv2d(256, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2) ) self.fc = nn.Linear(4*4*512, 1) def forward(self, x): x = self.conv(x) x = x.view(-1, 4*4*512) x = self.fc(x) return x ``` 接下来,我们定义WGAN模型: ```python class WGAN(object): def __init__(self): self.generator = Generator() self.discriminator = Discriminator() self.generator.cuda() self.discriminator.cuda() self.optimizer_g = torch.optim.RMSprop(self.generator.parameters(), lr=lr) self.optimizer_d = torch.optim.RMSprop(self.discriminator.parameters(), lr=lr) self.loss_fn = nn.MSELoss() def train(self, data_loader): total_step = len(data_loader) for epoch in range(n_epochs): for i, (images, _) in enumerate(data_loader): # 训练判别器 for j in range(n_critic): images = images.cuda() z = torch.randn(batch_size, z_dim).cuda() fake_images = self.generator(z) real_out = self.discriminator(images) fake_out = self.discriminator(fake_images.detach()) loss_d = -torch.mean(real_out) + torch.mean(fake_out) self.optimizer_d.zero_grad() loss_d.backward() self.optimizer_d.step() # 截断判别器的参数 for p in self.discriminator.parameters(): p.data.clamp_(-clip_value, clip_value) # 训练生成器 z = torch.randn(batch_size, z_dim).cuda() fake_images = self.generator(z) fake_out = self.discriminator(fake_images) loss_g = -torch.mean(fake_out) self.optimizer_g.zero_grad() loss_g.backward() self.optimizer_g.step() if (i+1) % 10 == 0: print ('Epoch [{}/{}], Step [{}/{}], Loss_D: {:.4f}, Loss_G: {:.4f}' .format(epoch+1, n_epochs, i+1, total_step, loss_d.item(), loss_g.item())) # 保存生成的图片 with torch.no_grad(): fake_images = self.generator(z) fake_images = fake_images.view(-1, 3, 64, 64) save_image(fake_images, 'generated_images-{}.png'.format(epoch+1)) ``` 最后,我们加载动漫头像数据集,并训练WGAN模型: ```python # 加载数据集 transform = transforms.Compose([ transforms.Resize(64), transforms.CenterCrop(64), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) dataset = ImageFolder('./data', transform) data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # 训练WGAN模型 wgan = WGAN() wgan.train(data_loader) ``` 训练完成后,我们可以使用生成器生成一些动漫头像: ```python # 加载训练好的生成器 generator = Generator() generator.load_state_dict(torch.load('./generator.pth')) # 生成动漫头像 z = torch.randn(64, z_dim).cuda() fake_images = generator(z) fake_images = fake_images.view(-1, 3, 64, 64) for i in range(64): plt.subplot(8, 8, i+1) plt.imshow((fake_images[i].cpu().detach().numpy().transpose(1, 2, 0)+1)/2) plt.axis('off') plt.show() ``` 至此,我们就完成了使用PyTorch实现WGAN自动生成动漫头像的代码
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值