【机器学习系列】变分推断第二讲:基于Mean Field的变分推断解法


作者:CHEONG

公众号:AI机器学习与知识图谱

研究方向:自然语言处理与知识图谱

阅读本文之前,首先注意以下两点:

1. 机器学习系列文章常含有大量公式推导证明,为了更好理解,文章在最开始会给出本文的重要结论,方便最快速度理解本文核心。需要进一步了解推导细节可继续往后看。

2. 文中含有大量公式,若读者需要获取含公式原稿Word文档,可关注公众号【AI机器学习与知识图谱】后回复:变分推断第二讲,可添加微信号【17865190919】进学习交流群,加好友时备注来自CSDN。原创不易,转载请告知并注明出处!

本文将先对变分推断所要解决的问题进行分析,然后给出基于Mean Field的变分推断解法。


一、本文结论

结论1: 变分推断的主要思想:在给定数据集 X X X下,问题是求后验概率 p p p,简单情况下后验概率 p p p可直接通过贝叶斯公式推导求出,但有些情况无法直接求解。因此变分推断想法是先假设另一个简单的概率分布 q q q,如高斯分布,通过优化 p p p q q q之间距离最小化,让概率分布 q q q逼近 p p p,这样就可以用概率分布 q q q近似表示后验概率 p p p

结论2: 基于Mean Field的变分推断方法主要是假设将隐变量 z z z分成M个相互独立的部分 z = ( z 1 , z 2 , . . . , z M ) z=(z_1,z_2,...,z_M) z=(z1,z2,...,zM) ,当求 q j ( z j ) q_j(z_j) qj(zj)时固定剩下M-1个部分。

结论3: 基于Mean Field的变分推断方法存在的两个问题:(1)假设将 z = ( z 1 , z 2 , . . . , z M ) z=(z_1,z_2,...,z_M) z=(z1,z2,...,zM)分成M个相互独立的部分,然后固定其他依次求得 q j ( z j ) q_j(z_j) qj(zj)。这个假设太强烈,在一些问题是无法分成相互独立的各个部分;(2)最后求出来的 q j ( z j ) q_j(z_j) qj(zj)仍然需要进行求积分,在一些问题中,仍然可能是Intractable,无法求解的。


二、问题分析

观测数据Observed Data: X X X

隐变量Latent Variable: Z Z Z

完整数据Complete Data: ( X , Z ) (X, Z) (X,Z)

目的: 求数据的后验概率 p ( z ∣ x ) p(z|x) p(zx),下面先给出变分推断的分析思路

在这里插入图片描述

首先由简单的联合概率分布的分解式引出问题,如下公式所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bnvN45Oj-1617958057057)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image023.png)]

通过两边加log变形为:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wD6jQnPH-1617958057060)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image025.png)]

为了近似求解后验概率 p ( z ∣ x ) p(z|x) p(zx),我们需要先引入另一个分布 q ( z ) q(z) q(z),整合进上面公式中:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OgrdLuv6-1617958057065)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image027.png)]

接下来分别将上式的左边和右边部分对 q ( z ) q(z) q(z)进行积分:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-A47Hl0wR-1617958057071)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image031.png)]

其中

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-37r1SigK-1617958057077)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image033.png)]

所以左边在积分后仍然是 l o g p ( x ) logp(x) logp(x),接下来对右边部分进行积分:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9F0pehv7-1617958057080)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image037.png)]

其中前半部分是Evidence Lower Bound,简称为 E L B O ELBO ELBO

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Y9TMxn46-1617958057085)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image039.png)]

后半部分是概率分布 p p p q q q的相对熵:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-T4nemFit-1617958057087)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image045.png)]

因此有:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-KbtHW55Z-1617958057091)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image047.png)]

因为当数据给定的情况下,左边 l o g p ( x ) logp(x) logp(x)是定值,即 E L B O + K L ( q ∣ ∣ p ) ELBO+KL(q||p) ELBO+KL(qp)是一个定值,而其中 K L ( q ∣ ∣ p ) KL(q||p) KL(qp)是大于等于0的,且 K L ( q ∣ ∣ p ) KL(q||p) KL(qp)越小代表概率分布 p p p q q q就越接近,也就是我们要优化的目标,但 K L ( q ∣ ∣ p ) KL(q||p) KL(qp)中包含后验概率不好直接优化最小,但因为 E L B O + K L ( q ∣ ∣ p ) ELBO+KL(q||p) ELBO+KL(qp)是定值,所以我们可以优化让 E L B O ELBO ELBO部分最大, K L ( q ∣ ∣ p ) KL(q||p) KL(qp)相对就越小,这样便可以用概率分布 q q q来代替 p p p了。


三、公式推导

通过上一小节的描述已经明确了变分推断需要优化的目标,总结为如下公式:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SJmEbkXc-1617958057095)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image055.png)]

下面通过公式推导求解是的 E L B O ELBO ELBO最大的后验概率 q ( z ) q(z) q(z)的值,使用基于Mean Field的变分推断的解法求解后验概率分布 p ( z ∣ x ) p(z|x) p(zx)

先假设 z = ( z 1 , z 2 , . . . , z M ) z=(z_1,z_2,...,z_M) z=(z1,z2,...,zM),并且这M份之间是相互独立的,则有:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-iPQcUKRe-1617958057099)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image061.png)]

接下来对 E L B O ELBO ELBO项进行展开,并将 q ( z ) q(z) q(z)的值代入:

在这里插入图片描述

下面为了简便,先做一下变量假设:

在这里插入图片描述

在推导 A A A B B B前,先固定 z = ( z 1 , . . . , z j − 1 , z j + 1 . . . , z M ) z=(z_1,...,z_{j-1}, z_{j+1}...,z_M) z=(z1,...,zj1,zj+1...,zM),先 z j z_j zj,接下来先推导 A A A

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4kDY5r1z-1617958057135)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image081.png)]

其中有:

在这里插入图片描述

因此可以得出 A A A的值如下:

在这里插入图片描述

接下来推导 B B B

在这里插入图片描述

其中有:

在这里插入图片描述

因此得出了 B B B的值:

在这里插入图片描述

因为固定了 z = ( z 1 , . . . , z j − 1 , z j + 1 . . . , z M ) z=(z_1,...,z_{j-1}, z_{j+1}...,z_M) z=(z1,...,zj1,zj+1...,zM),只求未知量 z j z_j zj,所以:

在这里插入图片描述

其中 C C C是常量,至此有:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tpDxZftt-1617958057172)(file:///C:/Users/zl_sd/AppData/Local/Temp/msohtmlclip1/01/clip_image099.png)]

因此当KL取0时, E L B O ELBO ELBO能达到最大值,所以这里求出 q j ( z j ) q_j(z_j) qj(zj)

在这里插入图片描述

其他的 q 1 ( z 1 ) , q 2 ( z 2 ) , , . . . , q M ( z M ) q_1(z_1),q_2(z_2),,...,q_M(z_M) q1(z1),q2(z2),,...,qM(zM)求解方法相同。这样求出了 q ∗ ( z ) q^{*}(z) q(z)求等价于求出了后验概率 p ( z ∣ x ) p(z|x) p(zx)


正如文章开头结论所说,基于Mean Field的变分推断方法存在的两个问题,下一节变分推断将介绍另一种解法:基于随机梯度上升SGD的变分推断推导方案:

1、假设将 z = ( z 1 , z 2 , . . . , z M ) z=(z_1,z_2,...,z_M) z=(z1,z2,...,zM) 分成M个相互独立的部分,然后固定其他依次求得 q j ( z j ) q_j(z_j) qj(zj)。这个假设太强烈,在一些问题是无法分成相互独立的各个部分;

2、最后求出来的 q j ( z j ) q_j(z_j) qj(zj)仍然是求积分,在一些问题中,仍然可能是Intractable,无法求解的。

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: 推断(variational inference)是一种用于在概率模型中近似推断潜在量的方法。在概率模型中,我们通常有观测数据和潜在量两个部分。我们希望通过观测数据集来估计潜在量的后验分布。然而,由于计算复杂度的限制,我们无法直接计算后验分布。 推断通过近似后验分布为一个简化的分布来解决这个问题。它会选择一个与真实后验分布相似的分布族,然后通过最小化这个分布与真实后验分布之间的差异来得到一个最佳的近似分布。这个问题可以转化为一个最优化问题,通常使用推断的一个常用方法是最大化证据下界(evidence lower bound,ELBO)来近似后验分布。 推断的一个重要特点是可以处理大规模和复杂的概率模型。由于近似分布是通过简化的分布族来表示的,而不是直接计算后验分布,所以它可以减少计算复杂度。此外,推断还可以通过引入额外的约束或假设来进一步简化近似分布,提高计算效率。 然而,推断也有一些缺点。因为近似分布是通过简化的分布族来表示的,所以它会引入一定的偏差。此外,推断的结果依赖于所选择的分布族,如果分布族选择不合适,可能会导致较差的近似结果。 总之,推断是一种用于近似计算概率模型中后验分布的方法,通过选择一个与真实后验分布相似的分布族,并最小化与真实后验分布之间的差异来得到一个最佳的近似分布。它具有处理大规模和复杂模型的能力,但也有一些局限性。 ### 回答2: 转推断(variational inference)是一种用于近似求解复杂概率模型的方法。它的核心思想是将复杂的后验分布近似为一个简单的分布,通过最小化这两个分布之间的差异来求解模型的参数。 推断通过引入一个简单分布(称为分分布)来近似复杂的后验分布。这个简单分布通常属于某个已知分布族,例如高斯分布或指数分布。推断通过最小化分分布和真实后验分布之间的差异,来找到最优的参数。 为了实现这一点,推断使用了KL散度(Kullback-Leibler divergence)这一概念。KL散度是用来衡量两个概率分布之间的差异的指标。通过最小化分分布与真实后验分布之间的KL散度,我们可以找到一个最优的分分布来近似真实后验分布。 推断的步骤通常包括以下几个步骤: 1. 定义分分布:选择一个简单的分布族作为分分布,例如高斯分布。 2. 定义目标函数:根据KL散度的定义,定义一个目标函数,通常包括模型的似然函数和分分布的熵。 3. 最优化:使用数值方法(例如梯度下降法)最小化目标函数,找到最优的分参数。 4. 近似求解:通过最优的分参数,得到近似的后验分布,并用于模型的推断或预测。 推断的优点是可以通过选择合适的分分布,来控制近似精度和计算复杂度之间的平衡。它可以应用于各种概率模型和机器学习任务,例如潜在量模型、深度学习和无监督学习等。 总而言之,转推断是一种用于近似求解复杂概率模型的方法,通过近似后验分布来求解模型的参数。它通过最小化分分布与真实后验分布之间的差异来实现近似求解。这个方法可以应用于各种概率模型和机器学习任务,具有广泛的应用价值。 ### 回答3: 推断(Variational Inference)是一种用于概率模型中的近似推断方法。它的目标是通过近似的方式来近似估计概率分布中的某些未知参数或隐量。 在概率模型中,我们通常希望得到后验概率分布,即给定观测数据的情况下,未知参数或隐量的概率分布。然而,由于计算复杂性的原因,我们往往无法直接计算后验分布。 推断通过引入一个称为分分布的简化分布,将原问题转化为一个优化问题。具体来说,我们假设分分布属于某个分布族,并通过优化一个目标函数,使得分分布尽可能接近真实的后验分布。 目标函数通常使用卡尔贝克-勒勒散度(Kullback-Leibler divergence)来度量分分布与真实后验分布之间的差异。通过最小化这个目标函数,我们可以找到最优的近似分布。在这个优化问题中,我们通常将问题转化为一个推断问题,其中我们需要优化关于分分布的参数。 推断的一个优点是可以应用于各种类型的概率模型,无论是具有连续随机量还是离散量。此外,推断还可以解决复杂的后验推断问题,如分贝叶斯方法和逐步推断等。 然而,推断也存在一些限制。例如,它通常要求选择一个合适的分分布族,并且该族必须在计算上可以处理。此外,推断还可能导致近似误差,因为我们将问题简化为一个优化问题,可能会导致对真实后验分布的一些信息丢失。 总而言之,推断是一种强大的近似推断方法,可以用于概率模型中的参数和隐量的估计。它通过引入分分布来近似计算复杂的后验概率分布,从而转化为一个优化问题。然而,需要注意选择合适的分分布族和可能的近似误差。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值