PyTorch 实战:计算 Wasserstein 距离
2019-09-23 18:42:56
This blog is copied from: https://mp.weixin.qq.com/s/nTUKYNxdiPK3xdOoSXvTJQ
最优传输理论及 Wasserstein 距离是很多读者都希望了解的基础,本文主要通过简单案例展示了它们的基本思想,并通过 PyTorch 介绍如何实战 W 距离。
机器学习中的许多问题都涉及到令两个分布尽可能接近的思想,例如在 GAN 中令生成器分布接近判别器分布就能伪造出逼真的图像。但是 KL 散度等分布的度量方法有很多局限性,本文则介绍了 Wasserstein 距离及 Sinkhorn 迭代方法,它们 GAN 及众多任务上都展示了杰出的性能。
在简单的情况下,我们假设从未知数据分布 p(x) 中观测到一些随机变量 x(例如,猫的图片),我们想要找到一个模型 q(x|θ)(例如一个神经网络)能作为 p(x) 的一个很好的近似。如果 p 和 q 的分布很相近,那么就表明我们的模型已经学习到如何识别猫。
因为 KL 散度可以度量两个分布的距离,所以只需要最小化 KL(q‖p) 就可以了。可以证明,最小化 KL(q‖p) 等价于最小化一个负对数似然,这样的做法在我们训练一个分类器时很常见。例如,对于变分自编码器来说,我们希望后验分布能够接近于某种先验分布,这也是我们通过最小化它们之间的 KL 散度来实现的。
尽管 KL 散度有很广泛的应用,在某些情况下,KL 散度则会失效。不妨考虑一下如下图所示的离散分布:
KL 散度假设这两个分布共享相同的支撑集(也就是说,它们被定义在同一个点集上)。因此,我们不能为上面的例子计算 KL 散度。由于这一个限制和其他计算方面的因素促使研究人员寻找一种更适合于计算两个分布之间差异的方法。
在本文中,作者将:
-
简单介绍最优传输问题
-
将 Sinkhorn 迭代描述为对解求近似
-
使用 PyTorch 计算 Sinkhorn 距离
-
描述用于计算 mini-batch 之间的距离的对该实现的扩展
移动概率质量函数
我们不妨把离散的概率分布想象成空间中分散的点的质量。我们可以观测这些带质量的点从一个分布移动到另一个分布需要做多少功,如下图所示:
接着,我们可以定义另一个度量标准,用以衡量移动做所有点所需要做的功。要想将这个直观的概念形式化定义下来,首先,我们可以通过引入一个耦合矩阵 P(coupling