A Detailed Derivation of the ELBO for Variational Diffusion Models
A Variational Diffusion Model optimizes its parameters through the Evidence Lower Bound (ELBO), enabling generation from noise to data. The derivation of the ELBO rests on marginalization, Jensen's inequality, and conditional-independence properties from probability theory. This article derives the expression $\text{ELBO}_{φ,θ}(x)$ in detail, starting from the basic decomposition of the log-likelihood and expanding every step, aimed at readers with a background in probability theory and deep learning.
The Goal and Setup of the ELBO
Background
The goal is to maximize $\log p(x)$, where $x = x_0$ is the observed data and $p(x_0)$ is the true data distribution. A diffusion model captures this distribution through a multi-step forward process $q_φ(x_{1:T}|x_0)$ and a reverse process $p_θ(x_{0:T})$. The ELBO provides an optimizable lower bound:
$$\text{ELBO}_{φ,θ}(x) = \mathbb{E}_{q_φ(x_1|x_0)}\big[\log p_θ(x_0|x_1)\big] - \mathbb{E}_{q_φ(x_{T-1}|x_0)}\big[D_{KL}\big(q_φ(x_T|x_{T-1}) \,\|\, p(x_T)\big)\big] - \sum_{t=1}^{T-1} \mathbb{E}_{q_φ(x_{t-1}, x_{t+1}|x_0)}\big[D_{KL}\big(q_φ(x_t|x_{t-1}) \,\|\, p_θ(x_t|x_{t+1})\big)\big]$$

where $x_T \sim \mathcal{N}(0, I)$.
For an intuitive explanation, see the author's earlier post: Evidence Lower Bound (ELBO) in Variational Diffusion Models, Explained.
Below we derive where this expression comes from.
Derivation
Step 1: Marginalizing the log-likelihood
The log-likelihood $\log p(x_0)$ (that is, $\log p(x)$) is obtained by marginalizing over all intermediate states $x_{1:T}$:
$$\log p(x_0) = \log \int p(x_{0:T}) \, dx_{1:T}$$
Here $x_{0:T} = \{x_0, x_1, \dots, x_T\}$ collects all states from $t = 0$ to $t = T$, and $p(x_{0:T})$ is the joint distribution of the reverse process.
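The marginalization identity can be made concrete with a tiny discrete example (hypothetical toy numbers, $T = 1$, binary states), where the integral reduces to a sum:

```python
import numpy as np

# Hypothetical toy joint p(x0, x1) with T = 1 and a binary state space;
# rows index x0, columns index x1, and all entries sum to 1.
joint = np.array([[0.10, 0.30],
                  [0.40, 0.20]])

# Marginalizing out the intermediate state x1 (the integral becomes a sum)
# recovers p(x0), whose log is the quantity the ELBO bounds.
p_x0 = joint.sum(axis=1)
log_p_x0 = np.log(p_x0)
print(p_x0)   # [0.4 0.6]
```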
Step 2: Introducing the variational distribution
To make optimization tractable, introduce the forward process $q_φ(x_{1:T}|x_0)$ as an auxiliary distribution and multiply and divide by it:
$$\log p(x_0) = \log \int \frac{p(x_{0:T})}{q_φ(x_{1:T}|x_0)}\, q_φ(x_{1:T}|x_0) \, dx_{1:T}$$
Rewriting the integral as an expectation:
$$= \log \mathbb{E}_{q_φ(x_{1:T}|x_0)}\left[\frac{p(x_{0:T})}{q_φ(x_{1:T}|x_0)}\right]$$
Step 3: Applying Jensen's inequality
Since $\log$ is concave, Jensen's inequality gives:
$$\log \mathbb{E}_{q_φ(x_{1:T}|x_0)}\left[\frac{p(x_{0:T})}{q_φ(x_{1:T}|x_0)}\right] \geq \mathbb{E}_{q_φ(x_{1:T}|x_0)}\left[\log \frac{p(x_{0:T})}{q_φ(x_{1:T}|x_0)}\right]$$
Equality holds if and only if the ratio $\frac{p(x_{0:T})}{q_φ(x_{1:T}|x_0)}$ is constant (i.e., the two distributions match exactly). Therefore:
$$\log p(x_0) \geq \mathbb{E}_{q_φ(x_{1:T}|x_0)}\left[\log \frac{p(x_{0:T})}{q_φ(x_{1:T}|x_0)}\right]$$
The right-hand side is the ELBO.
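The inequality is easy to verify numerically. In the sketch below, `w` stands in for the density ratio $p(x_{0:T})/q_φ(x_{1:T}|x_0)$ under samples from $q_φ$; the values are toy numbers, not a real diffusion model:

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy positive "density ratio" samples w = p/q drawn under q
# (illustrative values only, not a real diffusion model).
w = rng.uniform(0.5, 2.0, size=100_000)

lhs = np.log(np.mean(w))   # log E_q[w], which equals log p(x0) in the exact setting
rhs = np.mean(np.log(w))   # E_q[log w], the ELBO
print(lhs, rhs)            # lhs >= rhs; the Jensen gap is zero iff w is constant
```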
Step 4: Factorizing the joint distributions
The reverse distribution $p(x_{0:T})$
By the Markov property, each reverse transition depends only on the state one step ahead, so:
$$p(x_{0:T}) = p(x_T) \prod_{t=1}^{T} p(x_{t-1}|x_t)$$
- $p(x_T)$ is the prior, usually set to $\mathcal{N}(0, I)$.
- $p(x_{t-1}|x_t)$ is the reverse transition distribution, parameterized as $p_θ(x_{t-1}|x_t)$.
Pulling out the $t = 1$ factor:
$$p(x_{0:T}) = p(x_T)\, p(x_0|x_1) \prod_{t=2}^{T} p(x_{t-1}|x_t)$$
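For concreteness, with $T = 3$ this factorization expands to:

```latex
p(x_{0:3}) = p(x_3)\, p(x_2 \mid x_3)\, p(x_1 \mid x_2)\, p(x_0 \mid x_1)
```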
The forward distribution $q_φ(x_{1:T}|x_0)$
Likewise, $q_φ(x_{1:T}|x_0)$ is a Markov chain:
$$q_φ(x_{1:T}|x_0) = \prod_{t=1}^{T} q_φ(x_t|x_{t-1})$$
- $q_φ(x_1|x_0)$ is the transition from $x_0$ to $x_1$.
- $q_φ(x_T|x_{T-1})$ is the final step.
Pulling out the $t = T$ factor:
$$q_φ(x_{1:T}|x_0) = q_φ(x_T|x_{T-1}) \prod_{t=1}^{T-1} q_φ(x_t|x_{t-1})$$
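This factorization can be checked numerically for a scalar Gaussian chain. The sketch assumes a DDPM-style kernel $q_φ(x_t|x_{t-1}) = \mathcal{N}(\sqrt{\alpha_t}\, x_{t-1},\, 1-\alpha_t)$ with a hypothetical schedule; the derivation above does not depend on this particular choice:

```python
import numpy as np

def gauss_logpdf(x, mean, var):
    """Log-density of a 1-D Gaussian N(mean, var)."""
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

rng = np.random.default_rng(1)
T, alphas = 3, [0.9, 0.8, 0.7]   # hypothetical noise schedule, scalar state

# Sample a forward trajectory x_1..x_T starting from x_0 = 1.0,
# assuming the kernel q(x_t | x_{t-1}) = N(sqrt(a_t) x_{t-1}, 1 - a_t).
x = [1.0]
for a in alphas:
    x.append(np.sqrt(a) * x[-1] + np.sqrt(1 - a) * rng.standard_normal())

# log q(x_{1:T} | x_0): the Markov product becomes a sum of per-step log-densities.
steps = [gauss_logpdf(x[t], np.sqrt(alphas[t - 1]) * x[t - 1], 1 - alphas[t - 1])
         for t in range(1, T + 1)]
log_q = sum(steps)

# Pulling out the final step reproduces the split used above:
# q(x_{1:T}|x_0) = q(x_T|x_{T-1}) * prod_{t=1}^{T-1} q(x_t|x_{t-1}).
assert np.isclose(log_q, steps[-1] + sum(steps[:-1]))
```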
Step 5: Substituting into the ELBO
Substituting the factorizations of $p(x_{0:T})$ and $q_φ(x_{1:T}|x_0)$:
$$\log p(x_0) \geq \mathbb{E}_{q_φ(x_{1:T}|x_0)}\left[\log \frac{p(x_T)\, p(x_0|x_1) \prod_{t=2}^{T} p(x_{t-1}|x_t)}{q_φ(x_T|x_{T-1}) \prod_{t=1}^{T-1} q_φ(x_t|x_{t-1})}\right]$$
Regrouping the numerator and denominator into three factors:
$$= \mathbb{E}_{q_φ(x_{1:T}|x_0)}\left[\log\!\left(\frac{p(x_T)}{q_φ(x_T|x_{T-1})} \cdot p(x_0|x_1) \cdot \frac{\prod_{t=2}^{T} p(x_{t-1}|x_t)}{\prod_{t=1}^{T-1} q_φ(x_t|x_{t-1})}\right)\right]$$
Step 6: Splitting the expectation
By linearity of expectation, this separates into three terms:
$$= \mathbb{E}_{q_φ(x_{1:T}|x_0)}\left[\log \frac{p(x_T)}{q_φ(x_T|x_{T-1})}\right] + \mathbb{E}_{q_φ(x_{1:T}|x_0)}\big[\log p(x_0|x_1)\big] + \mathbb{E}_{q_φ(x_{1:T}|x_0)}\left[\log \frac{\prod_{t=2}^{T} p(x_{t-1}|x_t)}{\prod_{t=1}^{T-1} q_φ(x_t|x_{t-1})}\right]$$
Reindexing
Substituting $t \to t+1$ in $\prod_{t=2}^{T} p(x_{t-1}|x_t)$:
$$\prod_{t=2}^{T} p(x_{t-1}|x_t) = \prod_{t=1}^{T-1} p(x_t|x_{t+1})$$
Therefore:
$$\frac{\prod_{t=2}^{T} p(x_{t-1}|x_t)}{\prod_{t=1}^{T-1} q_φ(x_t|x_{t-1})} = \prod_{t=1}^{T-1} \frac{p(x_t|x_{t+1})}{q_φ(x_t|x_{t-1})}$$
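For concreteness, with $T = 3$ the reindexing reads:

```latex
% T = 3: shift t -> t+1 in the reverse product
\prod_{t=2}^{3} p(x_{t-1} \mid x_t) = p(x_1 \mid x_2)\, p(x_2 \mid x_3)
                                    = \prod_{t=1}^{2} p(x_t \mid x_{t+1})
```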
Step 7: Simplifying the expectations
- Reconstruction term:
$$\mathbb{E}_{q_φ(x_{1:T}|x_0)}\big[\log p(x_0|x_1)\big]$$
Since $\log p(x_0|x_1)$ depends only on $x_0$ and $x_1$, the expectation simplifies to:
$$= \mathbb{E}_{q_φ(x_1|x_0)}\big[\log p_θ(x_0|x_1)\big]$$
- Prior matching term:
$$\mathbb{E}_{q_φ(x_{1:T}|x_0)}\left[\log \frac{p(x_T)}{q_φ(x_T|x_{T-1})}\right]$$
$\log \frac{p(x_T)}{q_φ(x_T|x_{T-1})}$ depends only on $x_T$ and $x_{T-1}$. Since $q_φ(x_{1:T}|x_0)$ is a Markov chain, $q_φ(x_T|x_{T-1}, x_0) = q_φ(x_T|x_{T-1})$, so the expectation can be written as:
$$= \mathbb{E}_{q_φ(x_{T-1}|x_0)}\left[\mathbb{E}_{q_φ(x_T|x_{T-1})}\left[\log \frac{p(x_T)}{q_φ(x_T|x_{T-1})}\right]\right]$$
The inner expectation is exactly a negative KL divergence, so:
$$= -\mathbb{E}_{q_φ(x_{T-1}|x_0)}\big[D_{KL}\big(q_φ(x_T|x_{T-1}) \,\|\, p(x_T)\big)\big]$$
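When the forward kernel is Gaussian, this KL term has a closed form. The sketch below assumes a one-dimensional $q_φ(x_T|x_{T-1}) = \mathcal{N}(\mu, \sigma^2)$ against the standard-normal prior $p(x_T) = \mathcal{N}(0, 1)$ (toy values for $\mu$, $\sigma^2$) and checks the formula $D_{KL} = \tfrac{1}{2}(\sigma^2 + \mu^2 - 1 - \ln \sigma^2)$ against a Monte Carlo estimate:

```python
import numpy as np

def kl_gauss_std(mu, var):
    """Closed-form KL( N(mu, var) || N(0, 1) ) in one dimension."""
    return 0.5 * (var + mu**2 - 1.0 - np.log(var))

mu, var = 0.3, 0.5               # hypothetical parameters of q(xT | xT-1)
kl_exact = kl_gauss_std(mu, var)

# Monte Carlo check: KL = E_q[log q(x) - log p(x)] with samples x ~ q.
rng = np.random.default_rng(2)
x = mu + np.sqrt(var) * rng.standard_normal(200_000)
log_q = -0.5 * (np.log(2 * np.pi * var) + (x - mu) ** 2 / var)
log_p = -0.5 * (np.log(2 * np.pi) + x ** 2)
kl_mc = np.mean(log_q - log_p)
print(kl_exact, kl_mc)           # the two estimates agree closely
```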
- Transition terms:
$$\mathbb{E}_{q_φ(x_{1:T}|x_0)}\left[\log \prod_{t=1}^{T-1} \frac{p(x_t|x_{t+1})}{q_φ(x_t|x_{t-1})}\right]$$
Turning the log of the product into a sum of logs:
$$= \sum_{t=1}^{T-1} \mathbb{E}_{q_φ(x_{1:T}|x_0)}\left[\log \frac{p(x_t|x_{t+1})}{q_φ(x_t|x_{t-1})}\right]$$
Since $\log \frac{p(x_t|x_{t+1})}{q_φ(x_t|x_{t-1})}$ depends only on $x_{t-1}$, $x_t$, and $x_{t+1}$:
$$= \sum_{t=1}^{T-1} \mathbb{E}_{q_φ(x_{t-1}, x_t, x_{t+1}|x_0)}\left[\log \frac{p(x_t|x_{t+1})}{q_φ(x_t|x_{t-1})}\right]$$
Since $q_φ(x_t|x_{t-1}, x_0) = q_φ(x_t|x_{t-1})$ (Markov property):
$$= \sum_{t=1}^{T-1} \mathbb{E}_{q_φ(x_{t-1}, x_{t+1}|x_0)}\left[\mathbb{E}_{q_φ(x_t|x_{t-1})}\left[\log \frac{p(x_t|x_{t+1})}{q_φ(x_t|x_{t-1})}\right]\right]$$
The inner expectation is again a negative KL divergence:
$$= -\sum_{t=1}^{T-1} \mathbb{E}_{q_φ(x_{t-1}, x_{t+1}|x_0)}\big[D_{KL}\big(q_φ(x_t|x_{t-1}) \,\|\, p_θ(x_t|x_{t+1})\big)\big]$$
Step 8: Parameterization
Replacing $p(x_0|x_1)$ and $p(x_t|x_{t+1})$ with the learnable $p_θ(x_0|x_1)$ and $p_θ(x_t|x_{t+1})$ yields the final ELBO.
Summary of the derivation
- Starting from the integral form of $\log p(x_0)$, introduce $q_φ(x_{1:T}|x_0)$ and apply Jensen's inequality to obtain the ELBO as a lower bound.
- Using the Markov factorizations of $p(x_{0:T})$ and $q_φ(x_{1:T}|x_0)$, separate the bound into reconstruction, prior matching, and transition consistency terms.
- Simplifying the expectations relies on conditional independence, so that each term depends only on the variables it involves.
Implementation sketch (pseudocode; the helpers forward_transition, forward_multi_step, sample_pair, and kl_divergence are assumed, and kl_divergence takes the mean and variance of each Gaussian)

def elbo_loss(x0, model, T, alpha_schedule):
    """Monte-Carlo estimate of the ELBO for a batch x0 (pseudocode)."""
    elbo = 0.0

    # Reconstruction term: E_q[ log p_theta(x0 | x1) ]
    x1 = forward_transition(x0, alpha_schedule[1])
    elbo += torch.mean(model.log_prob_x0_given_x1(x0, x1))

    # Prior matching term: -E_q[ KL(q(xT | xT-1) || N(0, I)) ].
    # The KL compares distributions, not samples, so use the parameters
    # of q(xT | xT-1) against the standard-normal prior.
    xT_minus_1 = forward_multi_step(x0, alpha_schedule[:T])          # sample x_{T-1}
    q_mean, q_var = forward_transition(xT_minus_1, alpha_schedule[T], return_params=True)
    kl_final = kl_divergence(q_mean, q_var,
                             torch.zeros_like(q_mean), torch.ones_like(q_var))
    elbo -= torch.mean(kl_final)

    # Transition terms: -sum_t E_q[ KL(q(xt | xt-1) || p_theta(xt | xt+1)) ]
    for t in range(1, T):
        x_t_minus_1, x_t_plus_1 = sample_pair(x0, t, alpha_schedule)
        q_mean_t, q_var_t = forward_transition(x_t_minus_1, alpha_schedule[t], return_params=True)
        p_t = model.reverse(x_t_plus_1, t)                           # p_theta(xt | xt+1)
        kl_trans = kl_divergence(q_mean_t, q_var_t, p_t.mean, p_t.cov)
        elbo -= torch.mean(kl_trans)

    return elbo
Summary
The ELBO derivation starts from a lower bound on the log-likelihood and, via Jensen's inequality and the Markov property, decomposes it into three optimization targets: reconstruction, prior matching, and transition consistency. This reflects the probabilistic-modeling essence of diffusion models and provides the theoretical basis for training the reverse denoising process.
I hope this derivation has deepened your understanding!
Postscript
Written in Shanghai at 15:32 on March 5, 2025, with the assistance of the Grok 3 large language model.