Generative Adversarial Networks(WGAN、SAGAN、BigGAN)

在这里插入图片描述
此篇博文继续整理GAN的衍生版本。

Wasserstein Generative Adversarial Networks(WGAN)
GAN 在基于梯度下降训练时存在梯度消失的问题,特别是当真实样本和生成样本之间差距并不大,而且甚至近乎没有差距时, 其目标函数的 Jensen-Shannon散度将会是一个常数,这将导致想要优化的目标函数不连续。而且判别器训练的梯度很难把控更使其收敛不稳定。判别器训练得太好,生成器的梯度消失,loss 难以下降;判别器训练的不好,生成器的梯度不准确,loss 四处乱跑。所以为了解决训练不稳定的这些问题, Arjovsky 等提出了 WGAN,主要是提出了一种新的损失函数解决该问题。

首先原GAN的损失函数是以 Jensen-Shannon散度为度量:
L = min ⁡ G max ⁡ D E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p generated ( z ) [ 1 − log ⁡ D ( G ( z ) ) ] L=\min_{G}\max_{D}\mathbb{E}_{x\sim p{\text{data}}(x)}[\log{D(x)}] + \mathbb{E}_{z\sim p{\text{generated}}(z)}[1 - \log{D(G(z))}] L=GminDmaxExpdata(x)[logD(x)]+Ezpgenerated(z)[1logD(G(z))]

JS 散度是用于衡量两个分布之间的不同,它的值越大,就代表了两个分布相差越大。具体来说就是分别对Q和P的分布,算了两次KL散度,公式如下:
K L ( A ∣ ∣ B ) = ∫ − ∞ ∞ a ( x ) log ⁡ a ( x ) b ( x ) d x KL(A||B) = \int_{-\infty}^{\infty}a(x)\log{\frac{a(x)}{b(x)}} dx KL(AB)=a(x)logb(x)a(x)dx JS ( P ∣ ∣ Q ) = KL ( P ∣ ∣ P + Q 2 ) + KL ( Q ∣ ∣ P + Q 2 ) \text{JS}(P||Q) = \text{KL}(P || \frac{P+Q}{2}) + \text{KL}(Q || \frac{P+Q}{2}) JS(PQ)=KL(P2P+Q)+KL(Q2P+Q) L = 2 J S ( P ∣ ∣ Q ) − 2 l o g 2 L=2JS(P || Q)-2log2 L=2JS(PQ)2log2
带入原式子可以发现,当两个分布完全不同时,JS的值会保持为 2log2 的常量值。而函数为常量即代表了它的梯度为零,这意味着生成器啥也没学到。
在这里插入图片描述
而对比之下WGAN所采用的新损失函数是:
min ⁡ G max ⁡ D E x ∼ p data ( x ) [ D ( x ) ] − E z ∼ p generated ( z ) [ D ( G ( z ) ) ] \min_{G}\max_{D} \mathbb{E}_{x\sim p{\text{data}}(x)}[D(x)] - \mathbb{E}_{z \sim p{\text{generated}}(z)}[D(G(z))] GminDmaxExpdata(x)[D(x)]Ezpgenerated(z)[D(G(z))]

其替代距离度量JS的是 1-Wasserstein 距离,也被称为称为地球移动距离,即以两个沙坑之间土堆的变化过程来解释两个分布,其在移动土堆的成本用每粒沙的移动距离来计算:
E M ( P r , P θ ) = inf ⁡ γ ∈ Π , ∑ x , y ∥ x − y ∥ γ ( x , y ) = inf ⁡ γ ∈ Π   E ( x , y ) ∼ γ ∥ x − y ∥ \mathrm{EM}(P_r, P_\theta) = \inf_{\gamma \in \Pi} , \sum\limits_{x,y} \Vert x - y \Vert \gamma (x,y) = \inf_{\gamma \in \Pi} \ \mathbb{E}_{(x,y) \sim \gamma} \Vert x - y \Vert EM(Pr,Pθ)=γΠinf,x,yxyγ(x,y)=γΠinf E(x,y)γxy
其中 inf 是中位数(最小值),x 和 y 则是是两个土堆分布上的点,γ是指移动时的最佳方法。
E M ( P r , P θ ) = sup ⁡ ∥ f ∥ L ≤ 1   E x ∼ P r f ( x ) − E x ∼ P θ f ( x ) \mathrm{EM}(P_r, P_\theta) = \sup_{\lVert f \lVert_{L \leq 1}} \ \mathbb{E}{x \sim P_r} f(x) - \mathbb{E}{x \sim P_\theta} f(x) EM(Pr,Pθ)=fL1sup ExPrf(x)ExPθf(x)
不过WGAN 并没有改变 GAN 模型的结构,只是在优化方法上进行了改进。(其实它的损失函数与DRAGAN很像,只是WGAN是在全样本空间施加梯度惩罚,而DRAGAN只在训练样本附近施加梯度惩罚。)而且虽然看起来比较复杂,实际上在实践中需要:

  • 去掉判别器的最后一层的sigmoid
  • 对判别器使用梯度裁剪,将其取值限制在[-k,k]
  • 使用 RMSProp 或 SGD 并以较低的学习率进行优化

代码大部分没有什么不同:

#loss函数
def wasserstein_loss(self, y_true, y_pred):
        return K.mean(y_true * y_pred)
        
#在编译阶段进行设置
self.critic.compile(loss=self.wasserstein_loss,
            optimizer=optimizer,
            metrics=['accuracy'])

同时,作者后来又出了一个升级版WGAN-GP,把梯度裁剪变成梯度惩罚。
code:https://github.com/eriklindernoren/Keras-GAN
paper:https://arxiv.org/abs/1701.07875v3

在这里插入图片描述
Self-Attention Generative Adversarial Networks (SAGAN)
传统的CNN网络只能捕获局部的空间信息,视野有限,故大多数效果好的GANs生成的图像都是单个类或很少的类,而在学习多类图像时往往存在困难。SAGAN使用自注意力机制,尝试去学习图像的全局关系以解决这一问题,模型架构如下图。

在这里插入图片描述
对CNN网络卷积完后的图像分别用三个1x1的卷积核进行处理,以减少图像的通道数。然后对这三个向量进行Self-attention,即在Attention的计算过程中 K=V=Q,以更好的捕获内部结构关系。具体来说如果f 和 g各有32个卷积核,而h具有256个,处理后的特征将会有不同的关注点,通过自注意力使其自我关注和计算后,能够形成对图片的“全局观”。获取距离较远的相关区域的信息,能很好的提升生成图像的清晰度。

#主要区别在自注意力
class Self_Attn(nn.Module):
    def __init__(self,in_dim,activation):
        super(Self_Attn,self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        
        #Q、K、V。Q和K输出通道相同,V是其8倍
        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax  = nn.Softmax(dim=-1)
     
     #具体计算步骤
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        m_batchsize,C,width ,height = x.size()
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)
        energy =  torch.bmm(proj_query,proj_key) # transpose check
        attention = self.softmax(energy) # BX (N) X (N) 
        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N

        out = torch.bmm(proj_value,attention.permute(0,2,1) )
        out = out.view(m_batchsize,C,width,height)
        
        out = self.gamma*out + x
        return out,attention

code:https://github.com/heykeetae/Self-Attention-GAN
paper:https://arxiv.org/abs/1805.08318v1

在这里插入图片描述
BigGAN
天下武功,以量取胜。
BigGAN在SAGAN的基础上,用了一组TPU,证明了大量数据和复杂模型与高超调参技术,是完全可以显着提高GANS性能的,对比起步分数52.52,BigGAN 的得分是152.8。。。

code:https://github.com/huggingface/pytorch-pretrained-BigGAN
paper:https://arxiv.org/abs/1809.11096v2

GAN系列总结
关于损失函数的度量问题,LSGAN,f-GAN
模式崩溃问题,DRAGAN,MADGAN
收敛不稳定问题,WGAN,WGAN-GP
对偶学习,源域迁移问题,CycleGAN
条件控制问题,cGAN,IcGAN
图像质量(分辨率)提升,SAGAN,CoGAN
生成的准确性与多样性,IS,FID


《Generative Adversarial Networks: A Survey and Taxonomy》
《Generative Adversarial Networks - The Story So Far》

  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值