变分推断variational Inference

变分推断是一种处理贝叶斯推断问题的方法,尤其在复杂的参数空间和隐变量场景下。它通过近似后验概率来解决积分问题,包括经典变分推断(VI)和随机梯度变分推断(SGVI)。经典VI采用平均场理论假设,而SGVI则通过优化分布参数来逼近后验。重参数化技巧是SGVI中解决梯度估计方差问题的有效手段。

十一、变分推断

1.背景

{频率角度,优化问题{回归{Model策略算法{解析解数值解SVMEM等等贝叶斯角度,积分问题{贝叶斯Infernece(求贝叶斯后验)P(θ∣x)=P(x∣θ)P(θ)P(x)贝叶斯决策(预测,最后还是求贝叶斯后验)P(x~∣x)=∫θP(x~,θ∣x)dθ=∫θP(x~∣θ)P(θ∣x)dθ=Eθ∣x[P(x~∣θ)] \begin{cases} 频率角度,优化问题 \begin{cases} 回归 \begin{cases} Model\\ 策略\\ 算法 \begin{cases} 解析解\\ 数值解 \end{cases} \end{cases}\\ SVM\\ EM\\ 等等 \end{cases}\\ 贝叶斯角度,积分问题 \begin{cases} 贝叶斯Infernece(求贝叶斯后验)\\ P(\theta \mid x)=\frac{P(x \mid \theta)P(\theta)}{P(x)}\\ 贝叶斯决策(预测,最后还是求贝叶斯后验)\\ P(\tilde{x} \mid x)=\int_{\theta}P(\tilde{x},\theta \mid x) d\theta=\int_{\theta}P(\tilde{x} \mid \theta)P(\theta \mid x)d\theta=E_{\theta \mid x}[P(\tilde{x} \mid \theta)] \end{cases} \end{cases}Model{SVMEMInfernece()P(θx)=P(x)P(xθ)P(θ)()P(x~x)=θP(x~,θx)dθ=θP(x~θ)P(θx)dθ=Eθx[P(x~θ)]

Inference{精确推断(后验简单)近似推断/近似推断的期望(参数空间、隐变量非常复杂){确定性近似→VI随机近似→MCMC,MH,GibbsInference \begin{cases} 精确推断(后验简单)\\ 近似推断/近似推断的期望(参数空间、隐变量非常复杂)\\ \begin{cases} 确定性近似\to VI\\ 随机近似 \to MCMC,MH,Gibbs \end{cases} \end{cases}Inference/(){VIMCMC,MH,Gibbs

2.公式推导

xxx:observed data
zzz:later variable + parameter
(x,z)(x,z)(x,z):complete data

ELBO + KL
log⁡P(x)=L(q)+KL(q∣∣p)\log P(x)= L(q)+KL(q||p)logP(x)=L(q)+KL(qp)
q^(z)=arg⁡max⁡q(z)L(q)→q^(z)≈p(z∣x)\hat q(z)=\arg \max_{q(z)} L(q) \to \hat q(z) \approx p(z \mid x)q^(z)=argmaxq(z)L(q)q^(z)p(zx)

基于物理的平均场理论
q(z)=∏i=1Mqi(zi)q(z)=\prod_{i=1}^M q_i(z_i)q(z)=i=1Mqi(zi),计算时固定一维qj(zj)q_j(z_j)qj(zj)
L(q)=∫zq(z)log⁡P(x,z)dz−∫zq(z)log⁡q(z)dzL(q)=\int_z q(z) \log P(x,z)dz-\int_z q(z)\log q(z)dzL(q)=zq(z)logP(x,z)dzzq(z)logq(z)dz

∫zq(z)log⁡P(x,z)dz=∫z∏i=1Mqi(zi)log⁡P(x,z)dz=∫zjqj(zj)dzj(∫zi∏iMqi(zi)log⁡P(x,z)dzi)(i≠j)=∫zjqj(zj)dzjE∏iMqi(zi)[log⁡P(x,z)]∫zq(z)log⁡q(z)dz=∫z∏i=1Mqi(zi)log⁡∏i=1Mqi(zi)dz=∫z∏i=1Mqi(zi)∑i=1Mlog⁡qi(zi)dz=∫z∏i=1Mqi(zi)[log⁡q1(z1)+log⁡q2(z2)+⋯+log⁡qM(zM)]dz=∑i=1M∫ziqi(zi)log⁡qi(zi)dzi=∫zjqj(zj)log⁡qj(zj)dzj+CL(q)=∫zjqj(zj)log⁡p^(x,zj)qj(zj)dzj=−KL(qj∣∣p^(x,zj))≤0 \begin{aligned} \int_z q(z) \log P(x,z)dz &=\int_z \prod_{i=1}^M q_i(z_i) \log P(x,z)dz\\ &=\int_{z_j} q_j(z_j) dz_j\left ( \int_{z_i} \prod_{i}^M q_i(z_i) \log P(x,z) dz_i \right )(i \ne j)\\ &=\int_{z_j} q_j(z_j) dz_j E_{\prod_{i}^{M} q_i(z_i)}[\log P(x,z)]\\ \int_z q(z) \log q(z) dz &=\int_z \prod_{i=1}^{M}q_i(z_i) \log \prod_{i=1}^{M}q_i(z_i)dz\\ &=\int_z \prod_{i=1}^{M}q_i(z_i) \sum_{i=1}^{M} \log q_i(z_i)dz\\ &=\int_z \prod_{i=1}^{M}q_i(z_i) [ \log q_1(z_1)+\log q_2(z_2)+\cdots+\log q_M(z_M) ] dz\\ &=\sum_{i=1}^{M} \int_{z_i}q_i(z_i) \log q_i(z_i)dz_i\\ &=\int_{z_j}q_j(z_j) \log q_j(z_j)dz_j + C\\ L(q)&=\int_{z_j} q_j(z_j)\log \frac{\hat p(x,z_j)}{q_j(z_j)}dz_j\\ &=-KL(q_j||\hat p(x,z_j)) \le0 \end{aligned}zq(z)logP(x,z)dzzq(z)logq(z)dzL(q)=zi=1Mqi(zi)logP(x,z)dz=zjqj(zj)dzj(ziiMqi(zi)logP(x,z)dzi)(i=j)=zjqj(zj)dzjEiMqi(zi)[logP(x,z)]=zi=1Mqi(zi)logi=1Mqi(zi)dz=zi=1Mqi(zi)i=1Mlogqi(zi)dz=zi=1Mqi(zi)[logq1(z1)+logq2(z2)++logqM(zM)]dz=i=1Mziqi(zi)logqi(zi)dzi=zjqj(zj)logqj(zj)dzj+C=zjqj(zj)logqj(zj)p^(x,zj)dzj=KL(qjp^(x,zj))0

用EM算法求解含隐变量的极大似然估计,极大似然估计是关于后验概率的函数,将不等式划等号q=p,一般p很复杂不易求得,最小化KL可以得到最优解q^\hat qq^,基于平均场理论,使用相互独立的∏i=1Mqi(zi)\prod_{i=1}^{M}q_i(z_i)i=1Mqi(zi)近似推断后验p

3.再回首

VI (mean field)→\toClassical VI

Assumption:q(z)=∏i=1Mqi(zi)q(z)=\prod_{i=1}^{M}q_i(z_i)q(z)=i=1Mqi(zi)
log⁡qj(zj)=E∏iqi(zi)[log⁡Pθ(x(i),z)]+C=∫q1⋯∫qj−1∫qj+1⋯∫qMq1⋯qj−1qj+1⋯qM[log⁡Pθ(x(i)),z]dq1⋯dqj−1dqj+1⋯dqM \begin{aligned} \log q_j(z_j) &=E_{\prod_i q_i(z_i)}[\log P_{\theta}(x^{(i)},z)]+C\\ &=\int_{q_1}\cdots\int_{q_{j-1}} \int_{q_{j+1}} \cdots \int_{q_M}q_1 \cdots q_{j-1}q_{j+1} \cdots q_M [\log P_{\theta}(x^{(i)}),z]dq_1\cdots dq_{j-1} dq_{j+1} \cdots dq_M \end{aligned}logqj(zj)=Eiqi(zi)[logPθ(x(i),z)]+C=q1qj1qj+1qMq1qj1qj+1qM[logPθ(x(i)),z]dq1dqj1dqj+1dqM

目标函数:
q^=arg⁡min⁡qKL(q∣∣p)=arg⁡max⁡qL(q)q^1(z1)=∫q2⋯∫qMq2⋯qM[log⁡Pθ(x(i)),z]dq2⋯dqMq^2(z2)=∫q^1⋯∫qMq^1⋯qM[log⁡Pθ(x(i)),z]dq^1⋯dqMq^M(zM)=∫q^1∫q^2⋯∫q^M−1q^1q^2⋯q^M−1[log⁡Pθ(x(i)),z]dq^1q^2⋯dq^M−1\begin{aligned} &\hat q = \arg \min_q KL(q||p)=\arg \max_q L(q)\\ &\hat q_1(z_1) =\int_{q_2} \cdots \int_{q_M}q_2 \cdots q_M [\log P_{\theta}(x^{(i)}),z]dq_2 \cdots dq_M\\ &\hat q_2(z_2) =\int_{\hat q_1} \cdots \int_{q_M}\hat q_1 \cdots q_M [\log P_{\theta}(x^{(i)}),z]d\hat q_1 \cdots dq_M\\ &\hat q_M(z_M)=\int_{\hat q_1} \int_{\hat q_2} \cdots \int_{\hat q_{M-1}}\hat q_1 \hat q_2 \cdots \hat q_{M-1} [\log P_{\theta}(x^{(i)}),z]d\hat q_1\hat q_2 \cdots d\hat q_{M-1}\\ \end{aligned}q^=argqminKL(qp)=argqmaxL(q)q^1(z1)=q2qMq2qM[logPθ(x(i)),z]dq2dqMq^2(z2)=q^1qMq^1qM[logPθ(x(i)),z]dq^1dqMq^M(zM)=q^1q^2q^M1q^1q^2q^M1[logPθ(x(i)),z]dq^1q^2dq^M1

类似于坐标上升梯度上升,收敛终止

Classical VI存在的问题:

  • 假设太强
  • intractable(依然要求积分)

4.SGVI

随机梯度变分推断

不再求q(z)q(z)q(z)的具体值,假设q(z)q(z)q(z)服从某种分布,求这个分布的参数ϕ\phiϕ

BELO
L(ϕ)=Eqϕ(z)[log⁡Pθ(x(i),z)qϕ(z)]ϕ^=arg⁡max⁡L(ϕ)∇ϕL(ϕ)=∇ϕEqϕ(z)[log⁡Pθ(x(i),z)−log⁡qϕ(z)]=∇ϕ∫qϕ(z)[log⁡Pθ(x(i),z)−log⁡qϕ(z)]dz=∫∇ϕqϕ(z)[log⁡Pθ(x(i),z)−log⁡qϕ(z)]dz+∫qϕ(z)∇ϕ[log⁡Pθ(x(i),z)−log⁡qϕ(z)]dz=∫qϕ(z)∇ϕlog⁡qϕ(z)[log⁡Pθ(x(i),z)−log⁡qϕ(z)]dz−∫∇ϕqϕ(z)dz=Eqϕ(z)[∇ϕlog⁡qϕ(z)(log⁡Pθ(x(i),z)−log⁡qϕ(z))] \begin{aligned} L(\phi)&=E_{q_{\phi}(z)} \left [ \log \frac{P_{\theta}(x^{(i)},z)}{q_{\phi}(z)}\right ]\\ \hat \phi &= \arg \max L(\phi)\\ \nabla_{\phi}L(\phi) &=\nabla_{\phi}E_{q_{\phi}(z)} \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ]\\ &=\nabla_{\phi}\int q_{\phi}(z) \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] dz\\ &=\int \nabla_{\phi} q_{\phi}(z) \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] dz \\ & \quad +\int q_{\phi}(z) \nabla_{\phi} \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] dz\\ &=\int q_{\phi}(z) \nabla_{\phi} \log q_{\phi}(z) \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] dz-\int \nabla_{\phi} q_{\phi}(z)dz\\ &=E_{q_{\phi}(z)} \left [ \nabla_{\phi} \log q_{\phi}(z) \left ( \log P_{\theta}(x^{(i)},z) - \log q_{\phi}(z) \right ) \right ] \end{aligned}L(ϕ)ϕ^ϕL(ϕ)=Eqϕ(z)[logqϕ(z)Pθ(x(i),z)]=argmaxL(ϕ)=ϕEqϕ(z)[logPθ(x(i),z)logqϕ(z)]=ϕqϕ(z)[logPθ(x(i),z)logqϕ(z)]dz=ϕqϕ(z)[logPθ(x(i),z)logqϕ(z)]dz+qϕ(z)ϕ[logPθ(x(i),z)logqϕ(z)]dz=qϕ(z)ϕlogqϕ(z)[logPθ(x(i),z)logqϕ(z)]dzϕqϕ(z)dz=Eqϕ(z)[ϕlogqϕ(z)(logPθ(x(i),z)logqϕ(z))]
因此可以用MC,从qϕ(z)q_{\phi}(z)qϕ(z)中采样,根据大数定理,用均值近似期望EEE

z(l)∼qϕ(z),l=1,2,⋯ ,Lz^{(l)} \sim q_{\phi}(z),l=1,2,\cdots,Lz(l)qϕ(z),l=1,2,,L
≈1L∑i=1L∇ϕlog⁡qϕ(z(l))log⁡Pθ(x(i),z(l)−log⁡qϕ(z(l)))\approx \frac{1}{L} \sum_{i=1}^{L} \nabla_{\phi} \log q_{\phi}(z^{(l)})\log P_{\theta}(x^{(i)},z^{(l)}-\log q_{\phi}(z^{(l)}))L1i=1Lϕlogqϕ(z(l))logPθ(x(i),z(l)logqϕ(z(l)))

存在的问题
  在于这部分∇ϕlog⁡qϕ(z)\nabla_{\phi} \log q_{\phi}(z)ϕlogqϕ(z),当采样到的值接近于0时,在对数log中变化很快(很敏感,方差很大),需要更多的样本,才能比较好的近似;
yon用期望近似qϕ(z)q_{\phi}(z)qϕ(z)的梯度,而我们的目标函数是ϕ^\hat \phiϕ^,因此误差是非常大的。

Reparameterization Trick 重参化技巧

∇ϕL(ϕ)=∇ϕEqϕ(z)[log⁡Pθ(x(i),z)−log⁡qϕ(z)]\nabla_{\phi}L(\phi) =\nabla_{\phi}E_{q_{\phi}(z)} \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ]ϕL(ϕ)=ϕEqϕ(z)[logPθ(x(i),z)logqϕ(z)]
  期望是关于qϕ(z)q_{\phi}(z)qϕ(z)qϕ(z)q_{\phi}(z)qϕ(z)ϕ\phiϕ有关系,函数也和ϕ\phiϕ有关系,导致复杂度很高。为简化问题,假设qϕ(z)q_{\phi}(z)qϕ(z)ϕ\phiϕ没有关系,用一个确定的分布p(ε)p(\varepsilon)p(ε)替代qϕ(z)q_{\phi}(z)qϕ(z),就可以对直接对函数求导,不用对期望求导。z∼pϕ(z∣x)z \sim p_{\phi}(z \mid x)zpϕ(zx),引入重参化技巧把zzzϕ\phiϕ的关系解耦。

假设z=gϕ(ε,x(i)),ε∼p(ε)z=g_{\phi}(\varepsilon, x^{(i)}),\varepsilon \sim p(\varepsilon)z=gϕ(ε,x(i)),εp(ε)zzzε\varepsilonε为映射关系,各自的积分为1,有如下关系:
∣pϕ(z∣x(i))dz∣=∣p(ε)dε∣\left | p_{\phi}(z \mid x^{(i)})dz \right | = \left | p(\varepsilon)d\varepsilon \right |pϕ(zx(i))dz=p(ε)dε
∇ϕL(ϕ)=∇ϕ∫[log⁡Pθ(x(i),z)−log⁡qϕ(z)]qϕ(z)dz=∇ϕ∫[log⁡Pθ(x(i),z)−log⁡qϕ(z)]q(ε)dε=∇ϕEp(ε)[log⁡Pθ(x(i),z)−log⁡qϕ(z)]=Ep(ε)[∇ϕ(log⁡Pθ(x(i),z)−log⁡qϕ(z))]=Ep(ε)[∇z(log⁡Pθ(x(i),z)−log⁡qϕ(z∣x(i)))⋅∇ϕz]=Ep(ε)[∇z(log⁡Pθ(x(i),z)−log⁡qϕ(z∣x(i)))⋅∇ϕgϕ(ε,x(i))] \begin{aligned} \nabla_{\phi}L(\phi) &=\nabla_{\phi}\int \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] q_{\phi}(z) dz\\ &=\nabla_{\phi}\int \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] q(\varepsilon) d\varepsilon\\ &=\nabla_{\phi} E_{p(\varepsilon)} \left [ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ] \\ &=E_{p(\varepsilon)} \left [ \nabla_{\phi} \left ( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right ) \right ]\\ &=E_{p(\varepsilon)} \left [ \nabla_{z} \left ( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z \mid x^{(i)}) \right ) \cdot \nabla_{\phi}z \right ]\\ &=E_{p(\varepsilon)} \left [ \nabla_{z} \left ( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z \mid x^{(i)}) \right ) \cdot \nabla_{\phi} g_{\phi}(\varepsilon, x^{(i)}) \right ] \end{aligned} ϕL(ϕ)=ϕ[logPθ(x(i),z)logqϕ(z)]qϕ(z)dz=ϕ[logPθ(x(i),z)logqϕ(z)]q(ε)dε=ϕEp(ε)[logPθ(x(i),z)logqϕ(z)]=Ep(ε)[ϕ(logPθ(x(i),z)logqϕ(z))]=Ep(ε)[z(logPθ(x(i),z)logqϕ(zx(i)))ϕz]=Ep(ε)[z(logPθ(x(i),z)logqϕ(zx(i)))ϕgϕ(ε,x(i))]

### 回答1: 变分推断variational inference)是一种用于在概率模型中近似推断潜在变量的方法。在概率模型中,我们通常有观测数据和潜在变量两个部分。我们希望通过观测数据集来估计潜在变量的后验分布。然而,由于计算复杂度的限制,我们无法直接计算后验分布。 变分推断通过近似后验分布为一个简化的分布来解决这个问题。它会选择一个与真实后验分布相似的分布族,然后通过最小化这个分布与真实后验分布之间的差异来得到一个最佳的近似分布。这个问题可以转化为一个最优化问题,通常使用变分推断的一个常用方法是最大化证据下界(evidence lower bound,ELBO)来近似后验分布。 变分推断的一个重要特点是可以处理大规模和复杂的概率模型。由于近似分布是通过简化的分布族来表示的,而不是直接计算后验分布,所以它可以减少计算复杂度。此外,变分推断还可以通过引入额外的约束或假设来进一步简化近似分布,提高计算效率。 然而,变分推断也有一些缺点。因为近似分布是通过简化的分布族来表示的,所以它会引入一定的偏差。此外,变分推断的结果依赖于所选择的分布族,如果分布族选择不合适,可能会导致较差的近似结果。 总之,变分推断是一种用于近似计算概率模型中后验分布的方法,通过选择一个与真实后验分布相似的分布族,并最小化与真实后验分布之间的差异来得到一个最佳的近似分布。它具有处理大规模和复杂模型的能力,但也有一些局限性。 ### 回答2: 转变分推断variational inference)是一种用于近似求解复杂概率模型的方法。它的核心思想是将复杂的后验分布近似为一个简单的分布,通过最小化这两个分布之间的差异来求解模型的参数。 变分推断通过引入一个简单分布(称为变分分布)来近似复杂的后验分布。这个简单分布通常属于某个已知分布族,例如高斯分布或指数分布。变分推断通过最小化变分分布和真实后验分布之间的差异,来找到最优的参数。 为了实现这一点,变分推断使用了KL散度(Kullback-Leibler divergence)这一概念。KL散度是用来衡量两个概率分布之间的差异的指标。通过最小化变分分布与真实后验分布之间的KL散度,我们可以找到一个最优的变分分布来近似真实后验分布。 变分推断的步骤通常包括以下几个步骤: 1. 定义变分分布:选择一个简单的分布族作为变分分布,例如高斯分布。 2. 定义目标函数:根据KL散度的定义,定义一个目标函数,通常包括模型的似然函数和变分分布的熵。 3. 最优化:使用数值方法(例如梯度下降法)最小化目标函数,找到最优的变分参数。 4. 近似求解:通过最优的变分参数,得到近似的后验分布,并用于模型的推断或预测。 变分推断的优点是可以通过选择合适的变分分布,来控制近似精度和计算复杂度之间的平衡。它可以应用于各种概率模型和机器学习任务,例如潜在变量模型、深度学习和无监督学习等。 总而言之,转变分推断是一种用于近似求解复杂概率模型的方法,通过近似后验分布来求解模型的参数。它通过最小化变分分布与真实后验分布之间的差异来实现近似求解。这个方法可以应用于各种概率模型和机器学习任务,具有广泛的应用价值。 ### 回答3: 变分推断Variational Inference)是一种用于概率模型中的近似推断方法。它的目标是通过近似的方式来近似估计概率分布中的某些未知参数或隐变量。 在概率模型中,我们通常希望得到后验概率分布,即给定观测数据的情况下,未知参数或隐变量的概率分布。然而,由于计算复杂性的原因,我们往往无法直接计算后验分布。 变分推断通过引入一个称为变分分布的简化分布,将原问题转化为一个优化问题。具体来说,我们假设变分分布属于某个分布族,并通过优化一个目标函数,使得变分分布尽可能接近真实的后验分布。 目标函数通常使用卡尔贝克-勒勒散度(Kullback-Leibler divergence)来度量变分分布与真实后验分布之间的差异。通过最小化这个目标函数,我们可以找到最优的近似分布。在这个优化问题中,我们通常将问题转化为一个变分推断问题,其中我们需要优化关于变分分布的参数。 变分推断的一个优点是可以应用于各种类型的概率模型,无论是具有连续随机变量还是离散变量。此外,变分推断还可以解决复杂的后验推断问题,如变分贝叶斯方法和逐步变分推断等。 然而,变分推断也存在一些限制。例如,它通常要求选择一个合适的变分分布族,并且该族必须在计算上可以处理。此外,变分推断还可能导致近似误差,因为我们将问题简化为一个优化问题,可能会导致对真实后验分布的一些信息丢失。 总而言之,变分推断是一种强大的近似推断方法,可以用于概率模型中的参数和隐变量的估计。它通过引入变分分布来近似计算复杂的后验概率分布,从而转化为一个优化问题。然而,需要注意选择合适的变分分布族和可能的近似误差。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值