过去几年发表于各大 AI 顶会论文提出的 400 多种算法中,公开算法代码的仅占 6%,其中三分之一的论文作者分享了测试数据,约 54% 的分享包含“伪代码”。这是今年 AAAI 会议上一个严峻的报告。 人工智能这个蓬勃发展的领域正面临着实验重现的危机,就像实验重现问题过去十年来一直困扰着心理学、医学以及其他领域一样。最根本的问题是研究人员通常不共享他们的源代码。
可验证的知识是科学的基础,它事关理解。随着人工智能领域的发展,打破不可复现性将是必要的。为此,PaperWeekly 联手百度 PaddlePaddle 共同发起了本次论文有奖复现,我们希望和来自学界、工业界的研究者一起接力,为 AI 行业带来良性循环。
作者丨文永亮
学校丨华南理工大学
研究方向丨目标检测、图像生成
笔者这次选择复现的是 Least Squares Generative Adversarial Networks,也就是 LSGANs。
近几年来 GAN 是十分火热的,由 Goodfellow 在 14 年发表论文 Generative Adversarial Nets [1] 开山之作以来,生成式对抗网络一直都备受机器学习领域的关注,这种两人零和博弈的思想十分有趣,充分体现了数学的美感。从 GAN 到 WGAN [2] 的优化,再到本文介绍的 LSGANs,再到最近很火的 BigGAN [3],可以说生成式对抗网络的魅力无穷,而且它的用处也是非常奇妙,如今还被用在例如无负样本的情况下如何训练分类器,例如 AnoGAN [4]。
LSGANs 这篇经典的论文主要工作是把交叉熵损失函数换做了最小二乘损失函数,这样做作者认为改善了传统 GAN 的两个问题,即传统 GAN 生成的图片质量不高,而且训练过程十分不稳定。
LSGANs 试图使用不同的距离度量来构建一个更加稳定而且收敛更快的,生成质量高的对抗网络。但是我看过 WGAN 的论文之后分析这一损失函数,其实并不符合 WGAN 作者的分析。在下面我会详细分析一下为什么 LSGANs 其实并没有那么好用。
论文复现代码:
http://aistudio.baidu.com/aistudio/#/projectdetail/25767
LSGANs的优点
我们知道传统 GAN 生成的图片质量不高,传统的 GANs 使用的是交叉熵损失(sigmoid cross entropy)作为判别器的损失函数。
在这里说一下我对交叉熵的理解,有两个分布,分别是真实分布 p 和非真实分布 q。
信息熵是,就是按照真实分布 p 这样的样本空间表达能力强度的相反值,信息熵越大,不确定性越大,表达能力越弱,我们记作 H(p)。 交叉熵就是,可以理解为按照不真实分布 q 这样的样本空间表达能力强度的相反值,记作 H(p,q)。
KL 散度就是 D(p||q) = H(p,q) - H(p),它表示的是两个分布的差异,因为真实分布 p 的信息熵固定,所以一般由交叉熵来决定,所以这就是为什么传统 GAN 会采用交叉熵的缘故,论文也证明了 GAN 损失函数与 KL 散度的关系。
我们知道交叉熵一般都是拿来做逻辑分类的,而像最小二乘这种一般会用在线性回归中,这里为什么会用最小二乘作为损失函数的评判呢?
使用交叉熵虽然会让我们分类正确,但是这样会导致那些在决策边界被分类为真的、但是仍然远离真实数据的假样本(即生成器生成的样本)不会继续迭代,因为它已经成功欺骗了判别器,更新生成器的时候就会发生梯度弥散的问题。
论文指出最小二乘损失函数会对处于判别成真的那些远离决策边界的样本进行惩罚,把远离决策边界的假样本拖进决策边界,从而提高生成图片的质量。作者用下图详细表达了这一说法:
我们知道传统 GAN 的训练过程十分不稳定,这很大程度上是因为它的目标函数,尤其是在最小化目标函数时可能发生梯度弥散,使其很难再去更新生成器。而论文指出 LSGANs 可以解决这个问题,因为 LSGANs 会惩罚那些远离决策边界的样本,这些样本的梯度是梯度下降的决定方向。
论文指出因为传统 GAN 辨别器 D 使用的是 sigmoid 函数,并且由于 sigmoid 函数饱和得十分迅速,所以即使是十分小的数据点 x,该函数也会迅速忽略样本 x 到决策边界 w 的距离。这就意味着 sigmoid 函数本质上不会惩罚远离决策边界的样本,并且也说明我们满足于将 x 标注正确,因此辨别器 D 的梯度就会很快地下降到 0。
我们可以认为,交叉熵并不关心距离,而是仅仅关注于是否正确分类。正如论文作者在下图中所指出的那样,(a)图看到交叉熵损失很容易就达到饱和状态,而(b)图最小二乘损失只在一点达到饱和,作者认为这样训练会更加稳定。
LSGANs的损失函数
传统 GAN 的损失函数:
LSGANs 的损失函数:
其中 G 为生成器(Generator),D 为判别器(Discriminator),z 为噪音,它可以服从归一化或者高斯分布,为真实数据 x 服从的概率分布,