©PaperWeekly 原创 · 作者 | 苏剑林
单位 | 追一科技
研究方向 | NLP、神经网络
Wasserstein 距离(下面简称“W距离”),是基于最优传输思想来度量两个概率分布差异程度的距离函数,笔者之前在《从Wasserstein距离、对偶理论到WGAN》等文章中也做过介绍。
对于很多读者来说,第一次听说 W 距离,是因为 2017 年出世的 WGAN [1],它开创了从最优传输视角来理解 GAN 的新分支,也提高了最优传输理论在机器学习中的地位。很长一段时间以来,GAN [2] 都是生成模型领域的“主力军”,直到最近这两年扩散模型异军突起,GAN 的风头才有所下降,但其本身仍不失为一个强大的生成模型。
从形式上来看,扩散模型和 GAN 差异很明显,所以其研究一直都相对独立。不过,去年底的一篇论文《Score-based Generative Modeling Secretly Minimizes the Wasserstein Distance》[3] 打破了这个隔阂:它证明了扩散模型的得分匹配损失可以写成 W 距离的上界形式。这意味着在某种程度上,最小化扩散模型的损失函数,实则跟 WGAN 一样,都是在最小化两个分布的 W 距离。
结论分析
具体来说,原论文的结果,是针对《生成扩散模型漫谈:一般框架之SDE篇》中介绍的 SDE 式扩散模型的,其核心结论是不等式(其中 是 的非负函数,具体含义我们后来再详细介绍)
那么怎样理解这个不等式呢?首先,扩散模型可以理解为 SDE 从 到 的一个运动过程,最右边的 是 时刻的随机采样分布, 通常就是标准正态分布,而实际应用中一般都有 ,所以 ,原论文之所以显式写出它,只是为了从理论上给出最一般的结果。
接着,左边的 ,是从 采样的随机点出发,经反向 SDE
求解得到的 时刻的值的分布,它实际上就是要生成的数据分布;而 ,则是从 采样的随机点出发,经过 SDE
求解得到的 时刻的值的分布,其中 是 的神经网络近似,所以 实际就是扩散模型生成的数据分布。因此, 的含义就是数据分布与生成分布的 W 距离。
最后,剩下的积分项,其关键部分是
这也正好是扩散模型的“得分匹配”损失。所以,当我们用得分匹配损失去训练扩散模型的时候,其实也间接地最小化了数据分布与生成分布的 距离。跟 WGAN 不同的是,WGAN 优化的 距离是 而这里是 。
注:准确来说,式(4)还不是扩散模型的损失函数,扩散模型的损失函数应该是“条件得分匹配”,它跟得分匹配的关系是:
最后的结果才是扩散模型的损失函数“条件得分匹配”。第一个等号是因为恒等式 ,第二个不等号则是因为平方平均不等式的推广或者詹森不等式,第三个等号则是贝叶斯公式了。也就是说,条件得分匹配是得分匹配的上界