11. Variational Inference
1. Background
$$
\begin{cases}
\text{Frequentist view: optimization problems}
\begin{cases}
\text{Regression}
\begin{cases}
\text{model}\\
\text{strategy}\\
\text{algorithm}
\begin{cases}
\text{analytical solution}\\
\text{numerical solution}
\end{cases}
\end{cases}\\
\text{SVM}\\
\text{EM}\\
\text{etc.}
\end{cases}\\
\text{Bayesian view: integration problems}
\begin{cases}
\text{Bayesian inference (computing the posterior)}\\
P(\theta \mid x)=\frac{P(x \mid \theta)P(\theta)}{P(x)}\\
\text{Bayesian decision (prediction; again reduces to the posterior)}\\
P(\tilde{x} \mid x)=\int_{\theta}P(\tilde{x},\theta \mid x)\,d\theta=\int_{\theta}P(\tilde{x} \mid \theta)P(\theta \mid x)\,d\theta=E_{\theta \mid x}[P(\tilde{x} \mid \theta)]
\end{cases}
\end{cases}
$$
$$
\text{Inference}
\begin{cases}
\text{exact inference (the posterior is simple)}\\
\text{approximate inference (parameter space / latent variables are very complex)}
\begin{cases}
\text{deterministic approximation} \to \text{VI}\\
\text{stochastic approximation} \to \text{MCMC, MH, Gibbs}
\end{cases}
\end{cases}
$$
2. Derivation
$x$: observed data
$z$: latent variables + parameters
$(x,z)$: complete data
ELBO + KL
$$
\log P(x)= L(q)+KL(q\,\|\,p)
$$
$$
\hat q(z)=\arg \max_{q(z)} L(q) \Rightarrow \hat q(z) \approx p(z \mid x)
$$
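The identity $\log P(x)=L(q)+KL(q\|p)$ can be checked numerically. The sketch below uses a hypothetical discrete toy model (the probability values are made up for illustration): for any choice of $q(z)$, the ELBO and the KL term sum exactly to the evidence $\log P(x)$.

```python
import numpy as np

# Hypothetical toy model: z takes 3 values, one observation x.
# The joint p(x, z) over z is assumed for illustration.
p_xz = np.array([0.10, 0.25, 0.05])   # p(x, z) for z = 0, 1, 2
p_x = p_xz.sum()                      # evidence p(x)
p_post = p_xz / p_x                   # posterior p(z | x)

# An arbitrary variational distribution q(z).
q = np.array([0.5, 0.3, 0.2])

# ELBO: L(q) = E_q[log p(x, z) - log q(z)]
elbo = np.sum(q * (np.log(p_xz) - np.log(q)))
# KL(q || p(z | x))
kl = np.sum(q * (np.log(q) - np.log(p_post)))

# log p(x) = L(q) + KL(q || p), and since KL >= 0, the ELBO is a lower bound.
print(np.log(p_x), elbo + kl)
```

Because $KL \ge 0$, maximizing $L(q)$ over $q$ pushes $q$ toward the posterior, which is exactly the program of variational inference.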
Mean-field theory (borrowed from statistical physics)
$q(z)=\prod_{i=1}^M q_i(z_i)$; during the computation we isolate one factor $q_j(z_j)$ at a time, holding the others fixed.
$$
L(q)=\int_z q(z) \log P(x,z)\,dz-\int_z q(z)\log q(z)\,dz
$$
$$
\begin{aligned}
\int_z q(z) \log P(x,z)\,dz &=\int_z \prod_{i=1}^M q_i(z_i) \log P(x,z)\,dz\\
&=\int_{z_j} q_j(z_j)\left( \int_{z_{i \ne j}} \prod_{i \ne j} q_i(z_i) \log P(x,z)\, \prod_{i \ne j} dz_i \right) dz_j\\
&=\int_{z_j} q_j(z_j)\, E_{\prod_{i \ne j} q_i(z_i)}[\log P(x,z)]\, dz_j\\
\int_z q(z) \log q(z)\, dz &=\int_z \prod_{i=1}^{M}q_i(z_i) \log \prod_{i=1}^{M}q_i(z_i)\,dz\\
&=\int_z \prod_{i=1}^{M}q_i(z_i) \sum_{i=1}^{M} \log q_i(z_i)\,dz\\
&=\int_z \prod_{i=1}^{M}q_i(z_i) \left[ \log q_1(z_1)+\log q_2(z_2)+\cdots+\log q_M(z_M) \right] dz\\
&=\sum_{i=1}^{M} \int_{z_i}q_i(z_i) \log q_i(z_i)\,dz_i\\
&=\int_{z_j}q_j(z_j) \log q_j(z_j)\,dz_j + C
\end{aligned}
$$

Defining $\log \hat p(x,z_j) := E_{\prod_{i \ne j} q_i(z_i)}[\log P(x,z)]$, the two pieces combine into

$$
\begin{aligned}
L(q)&=\int_{z_j} q_j(z_j)\log \frac{\hat p(x,z_j)}{q_j(z_j)}\,dz_j + \text{const}\\
&=-KL(q_j\,\|\,\hat p(x,z_j)) + \text{const}
\end{aligned}
$$

which is maximized over $q_j$ when the KL term is zero, i.e. $q_j(z_j) = \hat p(x,z_j)$ (after normalization).
The EM algorithm solves maximum-likelihood estimation with latent variables; there the bound becomes tight by setting $q = p$. In general the posterior $p$ is too complex to obtain exactly, so instead we minimize the KL divergence to get the optimal $\hat q$: based on mean-field theory, we approximate the posterior $p$ with the product $\prod_{i=1}^{M}q_i(z_i)$ of mutually independent factors.
3. A Second Look
VI (mean field) $\to$ classical VI
Assumption: $q(z)=\prod_{i=1}^{M}q_i(z_i)$
$$
\begin{aligned}
\log q_j(z_j)
&=E_{\prod_{i \ne j} q_i(z_i)}[\log P_{\theta}(x^{(i)},z)]+C\\
&=\int_{z_1}\cdots\int_{z_{j-1}} \int_{z_{j+1}} \cdots \int_{z_M} q_1 \cdots q_{j-1}\,q_{j+1} \cdots q_M \,\log P_{\theta}(x^{(i)},z)\;dz_1\cdots dz_{j-1}\, dz_{j+1} \cdots dz_M + C
\end{aligned}
$$
Objective:
$$
\begin{aligned}
&\hat q = \arg \min_q KL(q\,\|\,p)=\arg \max_q L(q)\\
&\log \hat q_1(z_1) =\int_{z_2} \cdots \int_{z_M} q_2 \cdots q_M \,\log P_{\theta}(x^{(i)},z)\;dz_2 \cdots dz_M + C\\
&\log \hat q_2(z_2) =\int_{z_1} \int_{z_3} \cdots \int_{z_M} \hat q_1\, q_3 \cdots q_M \,\log P_{\theta}(x^{(i)},z)\;dz_1\, dz_3 \cdots dz_M + C\\
&\qquad\vdots\\
&\log \hat q_M(z_M)=\int_{z_1} \int_{z_2} \cdots \int_{z_{M-1}} \hat q_1 \hat q_2 \cdots \hat q_{M-1} \,\log P_{\theta}(x^{(i)},z)\;dz_1\, dz_2 \cdots dz_{M-1} + C
\end{aligned}
$$
This is coordinate ascent (analogous to gradient ascent); the updates are iterated until convergence.
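As a worked instance of these coordinate-ascent updates, the sketch below runs mean-field VI on the classic conjugate model $x_i \sim N(\mu, \tau^{-1})$ with priors $\mu \sim N(\mu_0, (\lambda_0\tau)^{-1})$ and $\tau \sim \mathrm{Gamma}(a_0, b_0)$, factorizing $q(\mu,\tau)=q(\mu)\,q(\tau)$. The prior values and the synthetic data are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=1.0, size=500)   # synthetic data (assumed)
N, xbar = len(x), x.mean()

# Assumed priors: mu ~ N(mu0, (lam0*tau)^-1), tau ~ Gamma(a0, b0)
mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0

# Mean-field factors: q(mu) = N(mu_N, 1/lam_N), q(tau) = Gamma(a_N, b_N)
E_tau = 1.0                       # initial guess for E_q[tau]
for _ in range(100):              # coordinate-ascent sweeps
    # Update q(mu), holding q(tau) fixed
    mu_N = (lam0 * mu0 + N * xbar) / (lam0 + N)
    lam_N = (lam0 + N) * E_tau
    # Update q(tau), holding q(mu) fixed
    a_N = a0 + (N + 1) / 2
    E_sq = np.sum((x - mu_N) ** 2 + 1 / lam_N)   # E_q(mu)[sum (x_i - mu)^2]
    b_N = b0 + 0.5 * (lam0 * ((mu0 - mu_N) ** 2 + 1 / lam_N) + E_sq)
    E_tau = a_N / b_N

print(mu_N, E_tau)   # should land near the true values (2.0, 1.0)
```

Each update has the exact form derived above: the optimal log-factor is the expectation of the complete-data log-likelihood under the current remaining factors.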
Problems with classical VI:
- the factorization assumption is too strong
- intractable (the expectations still require integration)
4. SGVI
Stochastic Gradient Variational Inference
Instead of solving for $q(z)$ pointwise, assume $q(z)$ belongs to some parametric family and solve for its parameter $\phi$.
ELBO
$$
\begin{aligned}
L(\phi)&=E_{q_{\phi}(z)} \left[ \log \frac{P_{\theta}(x^{(i)},z)}{q_{\phi}(z)}\right]\\
\hat \phi &= \arg \max_{\phi} L(\phi)\\
\nabla_{\phi}L(\phi)
&=\nabla_{\phi}E_{q_{\phi}(z)} \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right]\\
&=\nabla_{\phi}\int q_{\phi}(z) \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] dz\\
&=\int \nabla_{\phi} q_{\phi}(z) \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] dz +\int q_{\phi}(z) \nabla_{\phi} \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] dz\\
&=\int q_{\phi}(z) \nabla_{\phi} \log q_{\phi}(z) \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] dz-\int \nabla_{\phi} q_{\phi}(z)\,dz\\
&=E_{q_{\phi}(z)} \left[ \nabla_{\phi} \log q_{\phi}(z) \left( \log P_{\theta}(x^{(i)},z) - \log q_{\phi}(z) \right) \right]
\end{aligned}
$$

Here $\nabla_{\phi} q_{\phi}(z)=q_{\phi}(z)\nabla_{\phi}\log q_{\phi}(z)$ (the log-derivative trick), and the last integral vanishes because $\int \nabla_{\phi} q_{\phi}(z)\,dz = \nabla_{\phi}\int q_{\phi}(z)\,dz = \nabla_{\phi}1 = 0$.
We can therefore use Monte Carlo: sample from $q_{\phi}(z)$ and, by the law of large numbers, approximate the expectation with the sample mean.
$$
z^{(l)} \sim q_{\phi}(z),\quad l=1,2,\cdots,L
$$
$$
\nabla_{\phi}L(\phi) \approx \frac{1}{L} \sum_{l=1}^{L} \nabla_{\phi} \log q_{\phi}(z^{(l)})\left( \log P_{\theta}(x^{(i)},z^{(l)})-\log q_{\phi}(z^{(l)}) \right)
$$
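As a sanity check on this score-function estimator, here is a sketch on an assumed one-dimensional toy problem where the exact gradient is known: taking $\log P_{\theta}(x^{(i)},z) = -z^2/2$ (up to a constant) and $q_{\phi}(z)=N(m,s^2)$, the exact ELBO gradient with respect to $m$ is $-m$.

```python
import numpy as np

rng = np.random.default_rng(1)

# Assumed toy setup: log P(x, z) = -z^2/2 (constant absorbed),
# q_phi(z) = N(m, s^2); then dELBO/dm = -m exactly.
m, s = 2.0, 1.0

def log_p(z):   # log P(x, z) up to an additive constant
    return -0.5 * z ** 2

def log_q(z):   # log q_phi(z) up to an additive constant
    return -0.5 * ((z - m) / s) ** 2 - np.log(s)

L = 200_000
z = rng.normal(m, s, size=L)          # z^(l) ~ q_phi(z)
score = (z - m) / s ** 2              # grad_m log q_phi(z^(l))
grad_est = np.mean(score * (log_p(z) - log_q(z)))

print(grad_est)   # close to the exact value -m = -2.0
```

The estimator is unbiased, but the per-sample values spread widely around $-m$, which is the high-variance problem discussed next.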
Problems:
The factor $\nabla_{\phi} \log q_{\phi}(z)$ is troublesome: when a sampled point has density close to 0, the logarithm changes very fast (it is highly sensitive, so the estimator has high variance), and many samples are needed for a good approximation.
We are approximating the gradient of the objective with a sample expectation, yet the quantity we ultimately want is $\hat \phi$ itself, so the resulting error can be very large.
Reparameterization Trick
$$
\nabla_{\phi}L(\phi) =\nabla_{\phi}E_{q_{\phi}(z)} \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right]
$$
The expectation is taken with respect to $q_{\phi}(z)$, which depends on $\phi$, and the integrand also depends on $\phi$, so the gradient is hard to handle. To simplify, we remove the dependence of the sampling distribution on $\phi$: replace sampling from $q_{\phi}(z)$ with sampling from a fixed distribution $p(\varepsilon)$, so the gradient can be applied directly to the integrand rather than to the expectation. Since $z \sim q_{\phi}(z \mid x)$, the reparameterization trick decouples $z$ from $\phi$.
Assume $z=g_{\phi}(\varepsilon, x^{(i)})$ with $\varepsilon \sim p(\varepsilon)$; $z$ and $\varepsilon$ are linked by a deterministic mapping, each density integrates to 1, and the change of variables gives:
$$
\left| q_{\phi}(z \mid x^{(i)})\,dz \right| = \left| p(\varepsilon)\,d\varepsilon \right|
$$
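For the common Gaussian case $z = \mu + \sigma\varepsilon$ with $\varepsilon \sim N(0,1)$, this density-transport identity can be checked numerically, since $dz = \sigma\,d\varepsilon$. The parameter values below are assumed for illustration.

```python
import numpy as np

def normal_pdf(v, mean, std):
    # Density of N(mean, std^2), written out explicitly.
    return np.exp(-0.5 * ((v - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

mu, sigma = 1.5, 0.7            # variational parameters phi (assumed values)
eps = np.linspace(-3, 3, 101)   # grid of epsilon values, eps ~ N(0, 1)
z = mu + sigma * eps            # z = g_phi(eps) = mu + sigma * eps

# |q_phi(z) dz| = |p(eps) d eps|, with dz = sigma * d eps:
lhs = normal_pdf(z, mu, sigma) * sigma
rhs = normal_pdf(eps, 0.0, 1.0)
print(np.max(np.abs(lhs - rhs)))   # essentially zero
```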
$$
\begin{aligned}
\nabla_{\phi}L(\phi)
&=\nabla_{\phi}\int \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] q_{\phi}(z)\, dz\\
&=\nabla_{\phi}\int \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] p(\varepsilon)\, d\varepsilon\\
&=\nabla_{\phi} E_{p(\varepsilon)} \left[ \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right] \\
&=E_{p(\varepsilon)} \left[ \nabla_{\phi} \left( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z) \right) \right]\\
&=E_{p(\varepsilon)} \left[ \nabla_{z} \left( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z \mid x^{(i)}) \right) \cdot \nabla_{\phi}z \right]\\
&=E_{p(\varepsilon)} \left[ \nabla_{z} \left( \log P_{\theta}(x^{(i)},z)-\log q_{\phi}(z \mid x^{(i)}) \right) \cdot \nabla_{\phi} g_{\phi}(\varepsilon, x^{(i)}) \right]
\end{aligned}
$$
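Repeating the earlier assumed toy problem ($\log P_{\theta}(x^{(i)},z)=-z^2/2$, $q_{\phi}(z)=N(m,s^2)$, exact gradient $-m$) with the reparameterized estimator shows why it is preferred: the per-sample gradients collapse onto the exact value, so the variance drops dramatically (to essentially zero in this special case).

```python
import numpy as np

rng = np.random.default_rng(2)

# Same assumed toy setup: log P(x, z) = -z^2/2, q_phi(z) = N(m, s^2),
# exact gradient dELBO/dm = -m.
m, s = 2.0, 1.0
L = 1000

eps = rng.normal(size=L)          # eps ~ p(eps) = N(0, 1)
z = m + s * eps                   # z = g_phi(eps) = m + s * eps

grad_log_p = -z                   # grad_z log P(x, z)
grad_log_q = -(z - m) / s ** 2    # grad_z log q_phi(z)
grad_m_z = 1.0                    # grad_m g_phi(eps) = 1

samples = (grad_log_p - grad_log_q) * grad_m_z
grad_est = samples.mean()

print(grad_est, samples.var())    # close to -2.0; sample variance near 0
```

Compared with the score-function estimator, only 1000 samples are needed here, because the randomness enters through $\varepsilon$ while the gradient flows deterministically through $g_{\phi}$.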