1. 变分推断
1.1 什么是变分推断?(直观理解)
想象你在玩一个侦探游戏:你有一堆线索(观测数据 x x x),想找出幕后真相(隐变量 z z z)。在贝叶斯统计中,真相的分布是后验分布 p ( z ∣ x ) p(z|x) p(z∣x),但这个分布通常很复杂,计算起来像解一个超级复杂的拼图。变分推断就像是用一个简单的拼图(一个容易计算的分布 q ( z ) q(z) q(z))去“模仿”那个复杂的拼图,通过调整简单拼图的形状,让它尽可能接近真实的拼图。
核心目标:用一个简单的分布 q ( z ) q(z) q(z) 去近似复杂的后验分布 p ( z ∣ x ) p(z|x) p(z∣x),通过优化让两者差异最小。
1.2 为什么需要变分推断?
在贝叶斯推断中,我们想知道隐变量 z z z(比如文档的主题、数据的潜在类别)在给定数据 x x x 下的概率分布:
p ( z ∣ x ) = p ( x , z ) p ( x ) p(z|x) = \frac{p(x, z)}{p(x)} p(z∣x)=p(x)p(x,z)
- p ( x , z ) p(x, z) p(x,z) 是联合分布,描述数据和隐变量如何一起生成。
- p ( x ) = ∫ p ( x , z ) d z p(x) = \int p(x, z) dz p(x)=∫p(x,z)dz 是边际似然,相当于把所有可能的 z z z 都考虑进去,计算数据的总概率。
问题在于, p ( x ) = ∫ p ( x , z ) d z p(x) = \int p(x, z) dz p(x)=∫p(x,z)dz 是个高维积分,通常无法直接算出来(就像拼图太复杂,拼不完)。其他方法(如 MCMC 采样)虽然能近似,但计算量大,速度慢。变分推断把这个问题变成一个优化问题,用数学方法快速找到一个“足够好”的近似。
1.3 变分推断的核心思想
我们用一个简单的分布 q ( z ) q(z) q(z)(比如高斯分布或 Dirichlet 分布)来近似 p ( z ∣ x ) p(z|x) p(z∣x)。为了衡量 q ( z ) q(z) q(z) 和 p ( z ∣ x ) p(z|x) p(z∣x) 有多接近,我们用KL 散度(Kullback-Leibler 散度)来量化差异:
KL ( q ( z ) ∣ ∣ p ( z ∣ x ) ) = E q ( z ) [ log q ( z ) − log p ( z ∣ x ) ] \text{KL}(q(z) || p(z|x)) = \mathbb{E}_{q(z)}[\log q(z) - \log p(z|x)] KL(q(z)∣∣p(z∣x))=Eq(z)[logq(z)−logp(z∣x)]
KL 散度越小, q ( z ) q(z) q(z) 越接近 p ( z ∣ x ) p(z|x) p(z∣x)。但由于 p ( z ∣ x ) p(z|x) p(z∣x) 难以直接计算,我们转而优化一个等价的目标:证据下界(ELBO)。
ELBO 是什么?
ELBO 是对数边际似然 log p ( x ) \log p(x) logp(x) 的下界(lower bound)。通过最大化 ELBO,我们可以让 q ( z ) q(z) q(z) 尽量接近 p ( z ∣ x ) p(z|x) p(z∣x)。ELBO 的公式是:
ELBO = E q ( z ) [ log p ( x , z ) ] − E q ( z ) [ log q ( z ) ] \text{ELBO} = \mathbb{E}_{q(z)}[\log p(x, z)] - \mathbb{E}_{q(z)}[\log q(z)] ELBO=Eq(z)[logp(x,z)]−Eq(z)[logq(z)]
- 第一项 E q ( z ) [ log p ( x , z ) ] \mathbb{E}_{q(z)}[\log p(x, z)] Eq(z)[logp(x,z)]:衡量 q ( z ) q(z) q(z) 下的模型对数据的拟合程度。
- 第二项 − E q ( z ) [ log q ( z ) ] -\mathbb{E}_{q(z)}[\log q(z)] −Eq(z)[logq(z)]:是 q ( z ) q(z) q(z) 的熵,鼓励 q ( z ) q(z) q(z) 不要过于集中。
1.4 数学推导(一步步拆解)
为了让你更清楚 ELBO 的来源,我们慢慢推导:
- 从对数边际似然开始:
log p ( x ) = log ∫ p ( x , z ) d z \log p(x) = \log \int p(x, z) dz logp(x)=log∫p(x,z)dz
这是一个积分,很难直接算。 - 引入 q ( z ) q(z) q(z):
我们引入一个分布 q ( z ) q(z) q(z),将积分改写为期望:
log p ( x ) = log ∫ p ( x , z ) q ( z ) q ( z ) d z = log E q ( z ) [ p ( x , z ) q ( z ) ] \log p(x) = \log \int p(x, z) \frac{q(z)}{q(z)} dz = \log \mathbb{E}_{q(z)} \left[ \frac{p(x, z)}{q(z)} \right] logp(x)=log∫p(x,z)q(z)q(z)dz=logEq(z)[q(z)p(x,z)] - 应用 Jensen 不等式:
因为 log \log log 是凹函数,Jensen 不等式告诉我们:
log E q ( z ) [ p ( x , z ) q ( z ) ] ≥ E q ( z