今天第一篇为WGAN理论深入介绍。
在GAN的相关研究如火如荼甚至可以说是泛滥的今天,一篇新鲜出炉的arXiv论文Wassertein GAN 却在 Reddit 的 Machine Learning 频道火了,连Goodfellow都在帖子里和大家热烈讨论,这篇论文究竟有什么了不得的地方呢?
要知道自从2014年Ian Goodfellow提出以来,GAN就存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题。从那时起,很多论文都在尝试解决,但是效果不尽人意,比如最有名的一个改进DCGAN依靠的是对判别器和生成器的架构进行实验枚举,最终找到一组比较好的网络架构设置,但是实际上是治标不治本,没有彻底解决问题。而今天的主角Wasserstein GAN(下面简称WGAN)成功地做到了以下爆炸性的几点:
-
彻底解决GAN训练不稳定的问题,不再需要小心平衡生成器和判别器的训练程度
-
基本解决了collapse mode的问题,确保了生成样本的多样性
-
训练过程中终于有一个像交叉熵、准确率这样的数值来指示训练的进程,这个数值越小代表GAN训练得越好,代表生成器产生的图像质量越高(如题图所示)
-
以上一切好处不需要精心设计的网络架构,最简单的多层全连接网络就可以做到
那以上好处来自哪里?这就是令人拍案叫绝的部分了——实际上作者整整花了两篇论文,在第一篇《Towards Principled Methods for Training Generative Adversarial Networks》里面推了一堆公式定理,从理论上分析了原始GAN的问题所在,从而针对性地给出了改进要点;在这第二篇《Wassertein GAN》里面,又再从这个改进点出发推了一堆公式定理,最终给出了改进的算法实现流程,而改进后相比原始GAN的算法实现流程却只改了四点:
-
判别器最后一层去掉sigmoid
-
生成器和判别器的loss不取log
-
每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
-
不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行
算法截图如下:
改动是如此简单,效果却惊人地好,以至于Reddit上不少人在感叹:就这样?没有别的了? 太简单了吧!这些反应让我想起了一个颇有年头的鸡汤段子,说是一个工程师在电机外壳上用粉笔划了一条线排除了故障,要价一万美元——画一条线,1美元;知道在哪画线,9999美元。上面这四点改进就是作者Martin Arjovsky划的简简单单四条线,对于工程实现便已足够,但是知道在哪划线,背后却是精巧的数学分析,而这也是本文想要整理的内容。
本文内容分为五个部分:
-
原始GAN究竟出了什么问题?(此部分较长)
-
WGAN之前的一个过渡解决方案
-
Wasserstein距离的优越性质
-
从Wasserstein距离到WGAN
-
总结
理解原文的很多公式定理需要对测度论、 拓扑学等数学知识有所掌握,本文会从直观的角度对每一个重要公式进行解读,有时通过一些低维的例子帮助读者理解数学背后的思想,所以不免会失于严谨,如有引喻不当之处,欢迎在评论中指出。
以下简称 Wassertein GAN 为“WGAN本作”,简称 Towards Principled Methods for Training Generative Adversarial Networks 为“WGAN前作”。
WGAN 源码实现:https://github.com/martinarjovsky/WassersteinGAN
第一部分:原始GAN究竟出了什么问题?
回顾一下,原始GAN中判别器要最小化如下损失函数,尽可能把真实样本分为正例,生成样本分为负例:
其中 Pr 是真实样本分布,Pg 是由生成器产生的样本分布。对于生成器,Goodfellow 一开始提出来一个损失函数,后来又提出了一个改进的损失函数,分别是
后者在 WGAN 两篇论文中称为“the - log D alternative”或“the - log D trick”。WGAN 前作分别分析了这两种形式的原始GAN各自的问题所在,下面分别说明。
第一种原始GAN形式的问题
一句话概括:判别器越好,生成器梯度消失越严重。WGAN 前作从两个角度进行了论证,第一个角度是从生成器的等价损失函数切入的。
首先从公式1可以得到,在生成器 G 固定参数时最优的判别器 D 应该是什么。对于一个具体的样本,它可能来自真实分布也可能来自生成分布,它对公式1损失函数的贡献是
令其关于D(x)的导数为0,得
化简得最优判别器为: