MGO-GAN:利用正交向量缓解GAN训练时模式崩溃的问题

MGO-GAN:利用正交向量缓解GAN训练时模式崩溃的问题

Tackling mode collapse in multi-generator GANs with orthogonal vectors

论文地址:https://www.sciencedirect.com/science/article/abs/pii/S0031320320304490#section-cited-by

本文转发自知乎作者: 记忆的迷谷

本文《Tackling mode collapse in multi-generator GANs with orthogonal vectors》(简称MGO-GAN)使用多个生成器,并利用正交向量约束生成器间生成图像相似的情况,从而缓解GAN训练时模式崩溃的问题。

简介

众所周知,GAN在训练过程中存在模式崩溃的问题,训练鲁棒性强的GAN也是一大难题。当前提出了很多方式缓解此类问题,但在训练过程中容易出现负梯度等问题。本文提出了一种新的方式来训练GAN,通过一组生成器、一个编码器和一个鉴别器来克服模式崩溃。采用正交向量策略引导多个生成器以互补的方式学习不同的信息。具体来说,这些生成器生成的合成数据被输入编码器以获得特征向量。在任意两个特征向量之间计算正交值,从而反映了向量之间的相关性。这种相关性表明了生成器如何学习不同的信息。正交值越低,生成器学习的信息越不同。

什么是模式崩溃:

GAN主要由两部分组成:生成器和鉴别器,两者博弈达到纳什均衡的状态。然而由于训练时不稳定的状态,生成器往往倾向于生成一组相似的图片来欺骗鉴别器。(因为如果生成其他不同图片,多样性虽然会提升,但是惩罚可能更高,因此生成器可能会变得“保守”,干脆只生成相似的图片)。

具体来说,当我们使用神经网络将真实数据从原始数据空间映射到潜在空间时,映射的数据点通常位于不同大小的区域上。每次神经网络倾向从潜在空间的较大区域中选取数据点,忽略其他地方的数据点。此外,当生成的数据分布(PG)和真实数据分布(PR)有一个可忽略的重叠区域时,JS散度最大,导致梯度消失。在这种情况下,生成器显然无法捕获数据的所有模式,进一步加剧了模式崩溃。


我们以MINIST为例,如上图所示,模式崩溃发生时,生成的数字缺少2,5,6,尽管生成的数字质量较高,但是多样性不足,也就是发生了所谓的模式崩溃的情形。
总结一下,模式崩溃的三个原因如下:
(1)空间随机映射的策略,导致模型有更大的概率选择更大区域中的数据。
(2)噪声到原始空间的映射不是满射的
(3)鉴别器只判断是否是真实数据,不判断数据的多样性如何。

正交向量怎样约束生成图片的相似度?

首先阐述一下正交向量的定义:内积为0的两个或多个向量,即向量相互垂直。如果两个非零向量是正交的,它们必须是线性独立的。O(α,β)可以忠实地反映α和β之间的相关性。O(α,β)的值越小,两个向量包含的信息就越不同。

正交值表示生成器学习不同信息的方式。

两个矢量(α和β)由编码器产生,编码器将两个不同生成器(Gi和Gj)的输出作为其输入。O(α,β)的值越小,两个生成器获得的信息就越多。如果这个值比较大,说明生成图片相似,导致更大的loss,就会被惩罚。

正交值可以集成到GAN的训练中

为了保证在GAN的训练中可以最小化正交值,通过反向传播最小化正交值O(α,β)以及最小化生成器器损失,并将正交值O(α,β)与JS散度相结合联合更新生成器的参数。

理论依据:

假设输入的噪声几乎没有重叠区域的时候,虽然2JS-2log2为0,但是即使只有极少量重叠的时候,两个向量之间的余弦值也会不为0,所以梯度不会消失。

实现方式:

首先明确MGO-GAN需要解决的两个问题:
1)生成器如何学习不同的信息
2) MGO-GAN需要收敛到一个平衡点。

因此,MGO-GAN首先需要K个生成器共同构建一个函数,将随机噪声z映射到数据空间X。将Z上的先验分布定义为pz(Z),其中生成器从中采样噪声。原始数据遵循特定的分布,我们称之为pr(x)。即最大化鉴别器D,最小化生成器G1:K。


损失函数如上图所示。我们需要最大化真实数据通过判别器的期望,最小化判别器识别出假图的概率,同时还要让两两生成器之间的正交向量特征值尽可能小。(其中K表示发生器的数量,λ表示系数。E表示编码器,用于提取由K个生成器生成的生成数据的特征向量,以便将特征向量映射到相同的空间中,以计算相应的正交值。)

对D和G的优化方式如上图所示:对D来说,就是最大化真实数据通过判别器的概率,最大化判别器识别出K个生成器生成假图的平均概率。对G来说,就是最小化判别器识别出假图的概率,同时让两两生成器之间正交向量值越小越好。

上述是MGO-GAN的训练步骤:首先定义了正交值的计算公式:Cal-Orthogonal。然后首先采样m个噪声,然后采样m个真实数据。接着通过随机梯度下降策略更新D的网络参数,然后采样利用采样的m个噪声,然后将生成图片和正交向量两部分更新生成器,促使其生成多样化的图片。

MGO-GAN的结构如上图所示:p(z)代表输入的噪声,G1~Gk分别代表K个生成器,不同生成器生成的图片会被送入编码器中,E的作用是编码成特征向量,然后两两之间计算输出向量的正交值,作为损失函数的一部分,继续更新生成器G1~Gk。此外,G1~Gk需要送入判别器D中,进行极大极小的博弈。

实验结果:

在MINIST数据集上,相比于其他的GAN,MGO-GAN生成的图像更多样,且生成图片的每个类别的数量也相对更均衡一些。

同时,在FID指标上,MGO-GAN的FID值也是更低的。


在CelebaA数据集上,MGO-GAN的生成效果也会更真实多样一些,同时在下面的FID指标上更有优势。

更多前言论文/模型解析,尽在极链AI云

邀请新用户/新注册完成学生认证更可领取超大使用券

参与模型复现,奖励更丰厚哦~

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值