SMLD 的随机微分方程(SDE):从离散噪声到连续扩散

SMLD 的随机微分方程(SDE):从离散噪声到连续扩散

在扩散模型(Diffusion Models)的大家庭中,Score-Matching Langevin Dynamics(SMLD)是一种基于分数匹配的生成方法,与 DDPM(Denoising Diffusion Probabilistic Models)并驾齐驱。虽然 SMLD 没有显式的前向扩散步骤,但它可以通过噪声尺度的递增构造一个隐含的扩散过程,并用随机微分方程(SDE)描述其前向和逆向采样。本篇博客将面向深度学习研究者,介绍 SMLD 的 SDE 表示,包括前向和逆向过程的推导及其与 DDPM 的联系。


SMLD 的前向过程:隐含的扩散

与 DDPM 不同,SMLD 没有明确定义从数据到噪声的前向扩散步骤。它的训练基于一系列噪声水平 ( σ 1 , σ 2 , … , σ N \sigma_1, \sigma_2, \dots, \sigma_N σ1,σ2,,σN)(通常递增),通过分数匹配学习数据分布。但我们可以假设一个隐含的扩散过程,形式为:
x i = x i − 1 + σ i 2 − σ i − 1 2 z i − 1 , z i − 1 ∼ N ( 0 , I ) ( 4.18 ) x_i = x_{i-1} + \sqrt{\sigma_i^2 - \sigma_{i-1}^2} z_{i-1}, \quad z_{i-1} \sim \mathcal{N}(0, I) \quad(4.18) xi=xi1+σi2σi12 zi1,zi1N(0,I)(4.18)

  • ( x 0 x_0 x0 ):原始数据。
  • ( σ i 2 \sigma_i^2 σi2 ):第 ( i i i ) 步的噪声方差。
  • ( σ i 2 − σ i − 1 2 \sigma_i^2 - \sigma_{i-1}^2 σi2σi12 ):噪声方差的增量。
方差验证

假设 ( x i − 1 x_{i-1} xi1 ) 的方差为 ( σ i − 1 2 \sigma_{i-1}^2 σi12):
Var ( x i ) = Var ( x i − 1 + σ i 2 − σ i − 1 2 z i − 1 ) \text{Var}(x_i) = \text{Var}(x_{i-1} + \sqrt{\sigma_i^2 - \sigma_{i-1}^2} z_{i-1}) Var(xi)=Var(xi1+σi2σi12 zi1)
由于 ( z i − 1 z_{i-1} zi1 ) 是独立的高斯噪声:
Var ( x i ) = Var ( x i − 1 ) + Var ( σ i 2 − σ i − 1 2 z i − 1 ) = σ i − 1 2 + ( σ i 2 − σ i − 1 2 ) = σ i 2 \text{Var}(x_i) = \text{Var}(x_{i-1}) + \text{Var}(\sqrt{\sigma_i^2 - \sigma_{i-1}^2} z_{i-1}) = \sigma_{i-1}^2 + (\sigma_i^2 - \sigma_{i-1}^2) = \sigma_i^2 Var(xi)=Var(xi1)+Var(σi2σi12 zi1)=σi12+(σi2σi12)=σi2
这表明,公式 (4.18) 能逐步增加噪声方差,最终达到 ( σ N 2 \sigma_N^2 σN2)。

前向 SDE

将离散过程连续化,假设 ( σ i = σ ( i N ) \sigma_i = \sigma\left(\frac{i}{N}\right) σi=σ(Ni)) 是连续时间函数 ( σ ( t ) \sigma(t) σ(t)) 的采样,( t ∈ [ 0 , 1 ] t \in [0, 1] t[0,1] )。离散更新为:
x ( t + Δ t ) = x ( t ) + σ ( t + Δ t ) 2 − σ ( t ) 2 z ( t ) x(t + \Delta t) = x(t) + \sqrt{\sigma(t + \Delta t)^2 - \sigma(t)^2} z(t) x(t+Δt)=x(t)+σ(t+Δt)2σ(t)2 z(t)
近似 (下文有解释):
σ ( t + Δ t ) 2 − σ ( t ) 2 ≈ d [ σ ( t ) 2 ] d t Δ t \sigma(t + \Delta t)^2 - \sigma(t)^2 \approx \frac{d[\sigma(t)^2]}{dt} \Delta t σ(t+Δt)2σ(t)2dtd[σ(t)2]Δt
所以:
x ( t + Δ t ) ≈ x ( t ) + d [ σ ( t ) 2 ] d t Δ t z ( t ) x(t + \Delta t) \approx x(t) + \sqrt{\frac{d[\sigma(t)^2]}{dt} \Delta t} z(t) x(t+Δt)x(t)+dtd[σ(t)2]Δt z(t)
当 ( Δ t → 0 \Delta t \to 0 Δt0):
x ( t + Δ t ) − x ( t ) Δ t ≈ d [ σ ( t ) 2 ] d t z ( t ) Δ t \frac{x(t + \Delta t) - x(t)}{\Delta t} \approx \sqrt{\frac{d[\sigma(t)^2]}{dt}} \frac{z(t)}{\sqrt{\Delta t}} Δtx(t+Δt)x(t)dtd[σ(t)2] Δt z(t)
因为 ( z ( t ) Δ t = d w ( t ) Δ t \frac{z(t)}{\sqrt{\Delta t}} = \frac{dw(t)}{\Delta t} Δt z(t)=Δtdw(t) ),极限下:
d x = d [ σ ( t ) 2 ] d t   d w dx = \sqrt{\frac{d[\sigma(t)^2]}{dt}} \, dw dx=dtd[σ(t)2] dw
定理 4.3:SMLD 的前向采样 SDE 为:
d x = d [ σ ( t ) 2 ] d t   d w dx = \sqrt{\frac{d[\sigma(t)^2]}{dt}} \, dw dx=dtd[σ(t)2] dw

  • ( f ( x , t ) = 0 f(x, t) = 0 f(x,t)=0 ):无漂移项,仅靠噪声驱动。
  • ( g ( t ) = d [ σ ( t ) 2 ] d t g(t) = \sqrt{\frac{d[\sigma(t)^2]}{dt}} g(t)=dtd[σ(t)2] ):扩散项,控制噪声强度。

SMLD 的逆向过程:去噪采样

逆向过程从噪声 ( x ( 1 ) x(1) x(1) )(方差 ( σ N 2 \sigma_N^2 σN2))回到数据 ( x ( 0 ) x(0) x(0) )。根据逆向扩散的通用形式:
d x = [ f ( x , t ) − g ( t ) 2 ∇ x log ⁡ p t ( x ) ] d t + g ( t )   d w dx = [f(x, t) - g(t)^2 \nabla_x \log p_t(x)] dt + g(t) \, dw dx=[f(x,t)g(t)2xlogpt(x)]dt+g(t)dw
代入 SMLD 的前向参数:

  • ( f ( x , t ) = 0 f(x, t) = 0 f(x,t)=0 )
  • ( g ( t ) = d [ σ ( t ) 2 ] d t g(t) = \sqrt{\frac{d[\sigma(t)^2]}{dt}} g(t)=dtd[σ(t)2] )

得到:
d x = [ 0 − ( d [ σ ( t ) 2 ] d t ) 2 ∇ x log ⁡ p t ( x ) ] d t + d [ σ ( t ) 2 ] d t   d w dx = \left[ 0 - \left( \sqrt{\frac{d[\sigma(t)^2]}{dt}} \right)^2 \nabla_x \log p_t(x) \right] dt + \sqrt{\frac{d[\sigma(t)^2]}{dt}} \, dw dx= 0(dtd[σ(t)2] )2xlogpt(x) dt+dtd[σ(t)2] dw
即:
d x = − d [ σ ( t ) 2 ] d t ∇ x log ⁡ p t ( x )   d t + d [ σ ( t ) 2 ] d t   d w dx = -\frac{d[\sigma(t)^2]}{dt} \nabla_x \log p_t(x) \, dt + \sqrt{\frac{d[\sigma(t)^2]}{dt}} \, dw dx=dtd[σ(t)2]xlogpt(x)dt+dtd[σ(t)2] dw
定理 4.4:SMLD 的逆向采样 SDE 为:
d x = − d [ σ ( t ) 2 ] d t ∇ x log ⁡ p t ( x )   d t + d [ σ ( t ) 2 ] d t   d w dx = -\frac{d[\sigma(t)^2]}{dt} \nabla_x \log p_t(x) \, dt + \sqrt{\frac{d[\sigma(t)^2]}{dt}} \, dw dx=dtd[σ(t)2]xlogpt(x)dt+dtd[σ(t)2] dw

离散验证

令 ( α ( t ) = d [ σ ( t ) 2 ] d t \alpha(t) = \frac{d[\sigma(t)^2]}{dt} α(t)=dtd[σ(t)2]),逆向 SDE 为:
x ( t ) − x ( t + Δ t ) = − α ( t ) Δ t ∇ x log ⁡ p t ( x ) + α ( t ) Δ t z ( t ) x(t) - x(t + \Delta t) = -\alpha(t) \Delta t \nabla_x \log p_t(x) + \sqrt{\alpha(t) \Delta t} z(t) x(t)x(t+Δt)=α(t)Δtxlogpt(x)+α(t)Δt z(t)
时间反向:
x ( t − Δ t ) = x ( t ) + α ( t ) Δ t ∇ x log ⁡ p t ( x ) + α ( t ) Δ t z ( t ) x(t - \Delta t) = x(t) + \alpha(t) \Delta t \nabla_x \log p_t(x) + \sqrt{\alpha(t) \Delta t} z(t) x(tΔt)=x(t)+α(t)Δtxlogpt(x)+α(t)Δt z(t)
映射到离散:

  • ( x ( t ) = x i x(t) = x_i x(t)=xi )
  • ( x ( t − Δ t ) = x i − 1 x(t - \Delta t) = x_{i-1} x(tΔt)=xi1 )
  • ( α ( t ) Δ t = σ i 2 − σ i − 1 2 \alpha(t) \Delta t = \sigma_i^2 - \sigma_{i-1}^2 α(t)Δt=σi2σi12)

则:
x i − 1 = x i + ( σ i 2 − σ i − 1 2 ) ∇ x log ⁡ p i ( x i ) + σ i 2 − σ i − 1 2 z i x_{i-1} = x_i + (\sigma_i^2 - \sigma_{i-1}^2) \nabla_x \log p_i(x_i) + \sqrt{\sigma_i^2 - \sigma_{i-1}^2} z_i xi1=xi+(σi2σi12)xlogpi(xi)+σi2σi12 zi
这与 SMLD 的逆向迭代一致。


SMLD 与 DDPM 的对比
  • DDPM(VP SDE)

    • 前向:( d x = − β ( t ) 2 x   d t + β ( t )   d w dx = -\frac{\beta(t)}{2} x \, dt + \sqrt{\beta(t)} \, dw dx=2β(t)xdt+β(t) dw )
    • 逆向:( d x = − β ( t ) [ x 2 + ∇ x log ⁡ p t ( x ) ] d t + β ( t )   d w dx = -\beta(t) \left[ \frac{x}{2} + \nabla_x \log p_t(x) \right] dt + \sqrt{\beta(t)} \, dw dx=β(t)[2x+xlogpt(x)]dt+β(t) dw )
    • 特点:方差保持(Variance Preserving),有漂移项控制信号衰减。
  • SMLD(VE SDE)

    • 前向:( d x = d [ σ ( t ) 2 ] d t   d w dx = \sqrt{\frac{d[\sigma(t)^2]}{dt}} \, dw dx=dtd[σ(t)2] dw)
    • 逆向:( d x = − d [ σ ( t ) 2 ] d t ∇ x log ⁡ p t ( x )   d t + d [ σ ( t ) 2 ] d t   d w dx = -\frac{d[\sigma(t)^2]}{dt} \nabla_x \log p_t(x) \, dt + \sqrt{\frac{d[\sigma(t)^2]}{dt}} \, dw dx=dtd[σ(t)2]xlogpt(x)dt+dtd[σ(t)2] dw )
    • 特点:方差爆炸(Variance Exploding),无漂移,仅靠噪声累积。
等价性

Kawar et al. [https://arxiv.org/abs/2201.11793] 指出,VP 和 VE 的逆向推断过程在适当条件下等价。这意味着,在图像生成或修复等任务中,选择 VP(DDPM)还是 VE(SMLD)影响不大,但超参数(如 ( β ( t ) \beta(t) β(t)) 或 ( σ ( t ) \sigma(t) σ(t)))的选择会影响训练效果。


意义与应用
  1. 连续视角
    SMLD 的 SDE 表示将离散噪声添加过程升华为连续扩散,揭示其随机动态本质。

  2. 分数函数驱动
    逆向 SDE 中,( ∇ x log ⁡ p t ( x ) \nabla_x \log p_t(x) xlogpt(x))(分数函数)引导去噪,训练时用 ( s θ ( x i ) s_\theta(x_i) sθ(xi) ) 近似。

  3. 灵活采样
    可以用 SDE 求解器模拟逆向轨迹,也可沿用离散迭代。


总结

SMLD 的前向 SDE:
d x = d [ σ ( t ) 2 ] d t   d w dx = \sqrt{\frac{d[\sigma(t)^2]}{dt}} \, dw dx=dtd[σ(t)2] dw
逆向 SDE:
d x = − d [ σ ( t ) 2 ] d t ∇ x log ⁡ p t ( x )   d t + d [ σ ( t ) 2 ] d t   d w dx = -\frac{d[\sigma(t)^2]}{dt} \nabla_x \log p_t(x) \, dt + \sqrt{\frac{d[\sigma(t)^2]}{dt}} \, dw dx=dtd[σ(t)2]xlogpt(x)dt+dtd[σ(t)2] dw
分别描述了从数据到噪声的隐含扩散和从噪声到数据的去噪过程。与 DDPM 的 VP SDE 相比,SMLD 的 VE SDE 更强调方差的爆炸式增长。理解这些 SDE 为研究扩散模型的理论和优化提供了新视角,下一篇文章中,我们可以用 Python 模拟 SMLD 的 SDE 轨迹,敬请期待!


注:推导中简化了部分近似,注重直观理解。

为什么 ( σ ( t + Δ t ) 2 − σ ( t ) 2 ≈ d [ σ ( t ) 2 ] d t Δ t \sigma(t + \Delta t)^2 - \sigma(t)^2 \approx \frac{d[\sigma(t)^2]}{dt} \Delta t σ(t+Δt)2σ(t)2dtd[σ(t)2]Δt)?

推导近似

这个近似来源于微积分中的导数定义和泰勒展开。让我们一步步拆解。

1. 导数的定义

导数描述函数随自变量的变化率。对于函数 ( σ ( t ) 2 \sigma(t)^2 σ(t)2),其导数定义为:
d [ σ ( t ) 2 ] d t = lim ⁡ Δ t → 0 σ ( t + Δ t ) 2 − σ ( t ) 2 Δ t \frac{d[\sigma(t)^2]}{dt} = \lim_{\Delta t \to 0} \frac{\sigma(t + \Delta t)^2 - \sigma(t)^2}{\Delta t} dtd[σ(t)2]=Δt0limΔtσ(t+Δt)2σ(t)2
这意味着:
σ ( t + Δ t ) 2 − σ ( t ) 2 = d [ σ ( t ) 2 ] d t Δ t + 高阶余项 \sigma(t + \Delta t)^2 - \sigma(t)^2 = \frac{d[\sigma(t)^2]}{dt} \Delta t + \text{高阶余项} σ(t+Δt)2σ(t)2=dtd[σ(t)2]Δt+高阶余项
当 ( Δ t \Delta t Δt) 很小时,高阶余项(如 ( Δ t 2 \Delta t^2 Δt2) 项)可以忽略,近似为:
σ ( t + Δ t ) 2 − σ ( t ) 2 ≈ d [ σ ( t ) 2 ] d t Δ t \sigma(t + \Delta t)^2 - \sigma(t)^2 \approx \frac{d[\sigma(t)^2]}{dt} \Delta t σ(t+Δt)2σ(t)2dtd[σ(t)2]Δt

2. 泰勒展开的直观解释

更具体地,用泰勒展开表示 ( σ ( t + Δ t ) 2 \sigma(t + \Delta t)^2 σ(t+Δt)2):
σ ( t + Δ t ) 2 = σ ( t ) 2 + d [ σ ( t ) 2 ] d t Δ t + 1 2 d 2 [ σ ( t ) 2 ] d t 2 ( Δ t ) 2 + ⋯ \sigma(t + \Delta t)^2 = \sigma(t)^2 + \frac{d[\sigma(t)^2]}{dt} \Delta t + \frac{1}{2} \frac{d^2[\sigma(t)^2]}{dt^2} (\Delta t)^2 + \cdots σ(t+Δt)2=σ(t)2+dtd[σ(t)2]Δt+21dt2d2[σ(t)2](Δt)2+
两边减去 ( σ ( t ) 2 \sigma(t)^2 σ(t)2):
σ ( t + Δ t ) 2 − σ ( t ) 2 = d [ σ ( t ) 2 ] d t Δ t + 1 2 d 2 [ σ ( t ) 2 ] d t 2 ( Δ t ) 2 + 更高阶项 \sigma(t + \Delta t)^2 - \sigma(t)^2 = \frac{d[\sigma(t)^2]}{dt} \Delta t + \frac{1}{2} \frac{d^2[\sigma(t)^2]}{dt^2} (\Delta t)^2 + \text{更高阶项} σ(t+Δt)2σ(t)2=dtd[σ(t)2]Δt+21dt2d2[σ(t)2](Δt)2+更高阶项
当 ( Δ t → 0 \Delta t \to 0 Δt0)(即 ( N → ∞ N \to \infty N )),(( Δ t ) 2 \Delta t)^2 Δt)2) 及更高阶项变得非常小,可以忽略,只保留一阶项:
σ ( t + Δ t ) 2 − σ ( t ) 2 ≈ d [ σ ( t ) 2 ] d t Δ t \sigma(t + \Delta t)^2 - \sigma(t)^2 \approx \frac{d[\sigma(t)^2]}{dt} \Delta t σ(t+Δt)2σ(t)2dtd[σ(t)2]Δt

3. 物理直觉
  • ( σ ( t ) 2 \sigma(t)^2 σ(t)2) 是噪声方差随时间 ( t t t) 的函数。
  • ( σ ( t + Δ t ) 2 − σ ( t ) 2 \sigma(t + \Delta t)^2 - \sigma(t)^2 σ(t+Δt)2σ(t)2) 是时间间隔 ( Δ t \Delta t Δt) 内方差的增量。
  • ( d [ σ ( t ) 2 ] d t \frac{d[\sigma(t)^2]}{dt} dtd[σ(t)2]) 是方差的变化率,乘以 ( Δ t \Delta t Δt) 自然近似于增量。

就像速度 ( v = d x d t v = \frac{dx}{dt} v=dtdx ) 乘以时间 ( Δ t \Delta t Δt) 近似位移 ( Δ x = v Δ t \Delta x = v \Delta t Δx=vΔt),这里 ( d [ σ ( t ) 2 ] d t Δ t \frac{d[\sigma(t)^2]}{dt} \Delta t dtd[σ(t)2]Δt) 近似方差的变化。


注意事项
  • 高阶项的影响
    如果 ( Δ t \Delta t Δt) 不够小,二阶导数 ( d 2 [ σ ( t ) 2 ] d t 2 \frac{d^2[\sigma(t)^2]}{dt^2} dt2d2[σ(t)2]) 的贡献可能显著,近似会有偏差。
  • ( σ ( t ) \sigma(t) σ(t)) 的假设
    ( σ ( t ) \sigma(t) σ(t)) 需光滑可导,确保导数定义良好。

补充解释

布朗运动增量的正确定义

布朗运动 ( W ( t ) W(t) W(t) ) 的增量定义为:
W ( t + Δ t ) − W ( t ) ∼ N ( 0 , Δ t ) W(t + \Delta t) - W(t) \sim \mathcal{N}(0, \Delta t) W(t+Δt)W(t)N(0,Δt)
其微分形式为:
d W ( t ) = W ( t + Δ t ) − W ( t ) dW(t) = W(t + \Delta t) - W(t) dW(t)=W(t+Δt)W(t)
方差为 ( Δ t \Delta t Δt),标准差为 ( Δ t \sqrt{\Delta t} Δt )。如果用标准正态分布 ( z ( t ) ∼ N ( 0 , I ) z(t) \sim \mathcal{N}(0, I) z(t)N(0,I) ) 表示增量:
W ( t + Δ t ) − W ( t ) = Δ t z ( t ) W(t + \Delta t) - W(t) = \sqrt{\Delta t} z(t) W(t+Δt)W(t)=Δt z(t)
因为:
Var ( Δ t z ( t ) ) = ( Δ t ) 2 ⋅ Var ( z ( t ) ) = Δ t ⋅ 1 = Δ t \text{Var}(\sqrt{\Delta t} z(t)) = (\sqrt{\Delta t})^2 \cdot \text{Var}(z(t)) = \Delta t \cdot 1 = \Delta t Var(Δt z(t))=(Δt )2Var(z(t))=Δt1=Δt
所以:
d W ( t ) = Δ t z ( t ) dW(t) = \sqrt{\Delta t} z(t) dW(t)=Δt z(t)
这意味着:
d W ( t ) Δ t = Δ t z ( t ) Δ t = z ( t ) Δ t \frac{dW(t)}{\Delta t} = \frac{\sqrt{\Delta t} z(t)}{\Delta t} = \frac{z(t)}{\sqrt{\Delta t}} ΔtdW(t)=ΔtΔt z(t)=Δt z(t)

回到推导

回到:
x ( t + Δ t ) − x ( t ) Δ t = d [ σ ( t ) 2 ] d t z ( t ) Δ t \frac{x(t + \Delta t) - x(t)}{\Delta t} = \sqrt{\frac{d[\sigma(t)^2]}{dt}} \frac{z(t)}{\sqrt{\Delta t}} Δtx(t+Δt)x(t)=dtd[σ(t)2] Δt z(t)
右边正是:
d [ σ ( t ) 2 ] d t z ( t ) Δ t = d [ σ ( t ) 2 ] d t ⋅ d W ( t ) Δ t \sqrt{\frac{d[\sigma(t)^2]}{dt}} \frac{z(t)}{\sqrt{\Delta t}} = \sqrt{\frac{d[\sigma(t)^2]}{dt}} \cdot \frac{dW(t)}{\Delta t} dtd[σ(t)2] Δt z(t)=dtd[σ(t)2] ΔtdW(t)
当 ( Δ t → 0 \Delta t \to 0 Δt0):
x ( t + Δ t ) − x ( t ) Δ t → d x d t , d W ( t ) Δ t → d W d t \frac{x(t + \Delta t) - x(t)}{\Delta t} \to \frac{dx}{dt}, \quad \frac{dW(t)}{\Delta t} \to \frac{dW}{dt} Δtx(t+Δt)x(t)dtdx,ΔtdW(t)dtdW
于是:
d x d t = d [ σ ( t ) 2 ] d t d W d t \frac{dx}{dt} = \sqrt{\frac{d[\sigma(t)^2]}{dt}} \frac{dW}{dt} dtdx=dtd[σ(t)2] dtdW
微分形式:
d x = d [ σ ( t ) 2 ] d t   d W dx = \sqrt{\frac{d[\sigma(t)^2]}{dt}} \, dW dx=dtd[σ(t)2] dW
注意,这里 ( d W dW dW ) 是微分符号,习惯上写作 ( d w dw dw ),所以:
d x = d [ σ ( t ) 2 ] d t   d w dx = \sqrt{\frac{d[\sigma(t)^2]}{dt}} \, dw dx=dtd[σ(t)2] dw


参考

https://arxiv.org/pdf/2403.18103

后记

2025年3月9日16点29分于上海,在Grok 3大模型辅助下完成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值