深入理解连续时间变换:Normalizing Flows 的动态视角
在《Normalizing Flows for Probabilistic Modeling and Inference》第4节中,作者将目光转向连续时间变换(Continuous-Time Transformations),为normalizing flows提供了一个全新的构建视角。相比第3章的有限变换组合,连续时间流通过微分方程描述概率分布的演化,带来了更高的灵活性和理论深度。本文将重点剖析4.1小节“Definition”的数学定义和公式,帮助你理解这一动态框架的本质,而对4.2小节的求解与优化部分则简要带过。
以前的博客内容:
Normalizing Flows(一)入门:从基本定义到深度学习应用
Normalizing Flows(二)的表达能力:它有多“万能”?
Normalizing Flows(三)的建模与推断之道:从理论到实践
Autoregressive Flow:Normalizing Flows(四)的自回归之路
深入解析 Contractive Residual Flows(收缩残差流):Normalizing Flows(五)之残差流的收缩之美
Normalizing Flows(六)之基于矩阵行列式引理的残差流:Planar, Sylvester 和 Radial Flows
Paper: https://arxiv.org/pdf/1912.02762
4.1 Definition:连续时间流的数学基础
连续时间变换的核心思想是将概率分布的变换视为一个随时间演化的动态过程,而非离散的步骤序列。假设我们有一个初始分布
p
z
0
(
z
0
)
p_{\mathbf{z}_0}(\mathbf{z}_0)
pz0(z0),通过一个连续的变换过程,演化到目标分布
p
z
t
(
z
t
)
p_{\mathbf{z}_t}(\mathbf{z}_t)
pzt(zt),其中
t
∈
[
0
,
T
]
t \in [0, T]
t∈[0,T] 表示时间。变换由一个向量场驱动,具体形式为常微分方程(ODE):
d
z
(
t
)
d
t
=
f
(
z
(
t
)
,
t
;
ϕ
)
,
\frac{d\mathbf{z}(t)}{dt} = f(\mathbf{z}(t), t; \phi),
dtdz(t)=f(z(t),t;ϕ),
其中:
- z ( t ) ∈ R D \mathbf{z}(t) \in \mathbb{R}^D z(t)∈RD 是时间 t t t 处的状态变量,
- f : R D × R → R D f: \mathbb{R}^D \times \mathbb{R} \to \mathbb{R}^D f:RD×R→RD 是一个参数化的向量场(通常由神经网络实现),参数为 ϕ \phi ϕ,
- 初始条件为 z ( 0 ) = u \mathbf{z}(0) = \mathbf{u} z(0)=u, u ∼ p u ( u ) \mathbf{u} \sim p_{\mathbf{u}}(\mathbf{u}) u∼pu(u) 是基分布。
最终输出 z ( T ) = x \mathbf{z}(T) = \mathbf{x} z(T)=x,对应目标分布 p x ( x ) p_{\mathbf{x}}(\mathbf{x}) px(x)。这个过程可以看作从 u \mathbf{u} u 到 x \mathbf{x} x 的“流”(flow),由 f f f 定义的动态轨迹决定。
轨迹与可逆性
给定初始值
z
(
0
)
\mathbf{z}(0)
z(0),通过数值积分(如欧拉法或Runge-Kutta)求解上述ODE,可以得到
z
(
t
)
\mathbf{z}(t)
z(t) 的轨迹:
z
(
t
)
=
z
(
0
)
+
∫
0
t
f
(
z
(
s
)
,
s
;
ϕ
)
d
s
。
\mathbf{z}(t) = \mathbf{z}(0) + \int_0^t f(\mathbf{z}(s), s; \phi) \, ds。
z(t)=z(0)+∫0tf(z(s),s;ϕ)ds。
令
T
t
T_t
Tt 表示从时间 0 到
t
t
t 的变换,即
z
(
t
)
=
T
t
(
z
(
0
)
)
\mathbf{z}(t) = T_t(\mathbf{z}(0))
z(t)=Tt(z(0))。如果
f
f
f 满足Lipschitz连续性(通常通过神经网络设计保证),则根据ODE理论,
T
t
T_t
Tt 是可逆的,且其逆变换
T
t
−
1
T_t^{-1}
Tt−1 由反向ODE定义:
d
z
(
t
)
d
t
=
−
f
(
z
(
t
)
,
t
;
ϕ
)
,
z
(
T
)
=
x
。
\frac{d\mathbf{z}(t)}{dt} = -f(\mathbf{z}(t), t; \phi), \quad \mathbf{z}(T) = \mathbf{x}。
dtdz(t)=−f(z(t),t;ϕ),z(T)=x。
这意味着我们可以从
x
\mathbf{x}
x 回溯到
u
\mathbf{u}
u,保持normalizing flow的可逆性要求。
概率密度的演化
normalizing flow的关键在于计算目标密度
p
x
(
x
)
p_{\mathbf{x}}(\mathbf{x})
px(x)。连续时间流通过概率密度随时间的变化来实现这一点。考虑
z
(
t
)
\mathbf{z}(t)
z(t) 的密度
p
z
(
t
)
(
z
(
t
)
)
p_{\mathbf{z}(t)}(\mathbf{z}(t))
pz(t)(z(t)),其演化由连续性方程(continuity equation)描述:
∂
p
(
z
,
t
)
∂
t
=
−
∇
z
⋅
[
p
(
z
,
t
)
f
(
z
,
t
;
ϕ
)
]
,
\frac{\partial p(\mathbf{z}, t)}{\partial t} = -\nabla_{\mathbf{z}} \cdot [p(\mathbf{z}, t) f(\mathbf{z}, t; \phi)],
∂t∂p(z,t)=−∇z⋅[p(z,t)f(z,t;ϕ)],
其中
∇
z
⋅
\nabla_{\mathbf{z}} \cdot
∇z⋅ 表示散度算子。这一方程刻画了概率质量如何随向量场
f
f
f “流动”。然而,直接求解此偏微分方程(PDE)在高维中往往不可行,因此我们转向等价的密度变化公式。
密度变化的微分形式
更实用的方法是跟踪密度沿轨迹的变化。定义对数密度
log
p
(
z
(
t
)
,
t
)
\log p(\mathbf{z}(t), t)
logp(z(t),t),其时间导数为:
d
d
t
log
p
(
z
(
t
)
,
t
)
=
−
∇
z
⋅
f
(
z
(
t
)
,
t
;
ϕ
)
。
\frac{d}{dt} \log p(\mathbf{z}(t), t) = -\nabla_{\mathbf{z}} \cdot f(\mathbf{z}(t), t; \phi)。
dtdlogp(z(t),t)=−∇z⋅f(z(t),t;ϕ)。
证明如下:根据链式法则,
d
d
t
log
p
(
z
(
t
)
,
t
)
=
∂
log
p
∂
t
+
∂
log
p
∂
z
⋅
d
z
d
t
。
\frac{d}{dt} \log p(\mathbf{z}(t), t) = \frac{\partial \log p}{\partial t} + \frac{\partial \log p}{\partial \mathbf{z}} \cdot \frac{d\mathbf{z}}{dt}。
dtdlogp(z(t),t)=∂t∂logp+∂z∂logp⋅dtdz。
代入连续性方程
∂
p
∂
t
=
−
∇
z
⋅
(
p
f
)
\frac{\partial p}{\partial t} = -\nabla_{\mathbf{z}} \cdot (p f)
∂t∂p=−∇z⋅(pf) 和
d
z
d
t
=
f
\frac{d\mathbf{z}}{dt} = f
dtdz=f,并利用
∂
log
p
∂
z
=
1
p
∂
p
∂
z
\frac{\partial \log p}{\partial \mathbf{z}} = \frac{1}{p} \frac{\partial p}{\partial \mathbf{z}}
∂z∂logp=p1∂z∂p,有:
∂
log
p
∂
t
=
−
1
p
∇
z
⋅
(
p
f
)
,
∂
log
p
∂
z
⋅
f
=
1
p
∂
p
∂
z
⋅
f
。
\frac{\partial \log p}{\partial t} = -\frac{1}{p} \nabla_{\mathbf{z}} \cdot (p f), \quad \frac{\partial \log p}{\partial \mathbf{z}} \cdot f = \frac{1}{p} \frac{\partial p}{\partial \mathbf{z}} \cdot f。
∂t∂logp=−p1∇z⋅(pf),∂z∂logp⋅f=p1∂z∂p⋅f。
合并后:
d
d
t
log
p
=
−
1
p
[
∇
z
⋅
(
p
f
)
−
∂
p
∂
z
⋅
f
]
=
−
1
p
[
p
∇
z
⋅
f
+
f
⋅
∂
p
∂
z
−
f
⋅
∂
p
∂
z
]
=
−
∇
z
⋅
f
。
\frac{d}{dt} \log p = -\frac{1}{p} \left[ \nabla_{\mathbf{z}} \cdot (p f) - \frac{\partial p}{\partial \mathbf{z}} \cdot f \right] = -\frac{1}{p} \left[ p \nabla_{\mathbf{z}} \cdot f + f \cdot \frac{\partial p}{\partial \mathbf{z}} - f \cdot \frac{\partial p}{\partial \mathbf{z}} \right] = -\nabla_{\mathbf{z}} \cdot f。
dtdlogp=−p1[∇z⋅(pf)−∂z∂p⋅f]=−p1[p∇z⋅f+f⋅∂z∂p−f⋅∂z∂p]=−∇z⋅f。
沿轨迹积分,从
t
=
0
t=0
t=0 到
t
=
T
t=T
t=T:
log
p
x
(
x
)
=
log
p
u
(
u
)
−
∫
0
T
∇
z
⋅
f
(
z
(
t
)
,
t
;
ϕ
)
d
t
。
\log p_{\mathbf{x}}(\mathbf{x}) = \log p_{\mathbf{u}}(\mathbf{u}) - \int_0^T \nabla_{\mathbf{z}} \cdot f(\mathbf{z}(t), t; \phi) \, dt。
logpx(x)=logpu(u)−∫0T∇z⋅f(z(t),t;ϕ)dt。
这里
u
=
z
(
0
)
\mathbf{u} = \mathbf{z}(0)
u=z(0),
x
=
z
(
T
)
\mathbf{x} = \mathbf{z}(T)
x=z(T)。这与离散流的变量变换公式
p
x
(
x
)
=
p
u
(
u
)
∣
det
J
T
−
1
(
x
)
∣
p_{\mathbf{x}}(\mathbf{x}) = p_{\mathbf{u}}(\mathbf{u}) \left| \operatorname{det} J_{T^{-1}}(\mathbf{x}) \right|
px(x)=pu(u)∣detJT−1(x)∣ 等价,区别在于连续时间流用积分替代了离散的雅可比行列式。
雅可比行列式的联系
事实上, ∫ 0 T ∇ z ⋅ f ( z ( t ) , t ; ϕ ) d t \int_0^T \nabla_{\mathbf{z}} \cdot f(\mathbf{z}(t), t; \phi) \, dt ∫0T∇z⋅f(z(t),t;ϕ)dt 是 log ∣ det J T 0 T ( z ( 0 ) ) ∣ \log \left| \operatorname{det} J_{T_0^T}(\mathbf{z}(0)) \right| log detJT0T(z(0)) 的连续版本。离散流中, T = T K ∘ ⋯ ∘ T 1 T = T_K \circ \cdots \circ T_1 T=TK∘⋯∘T1 的行列式是对各步的累积,而连续流将其推广为时间上的积分。这种形式避免了直接计算 D × D D \times D D×D 雅可比矩阵的 O ( D 3 ) \mathcal{O}(D^3) O(D3) 复杂度,但需要高效估计散度。
为什么用连续时间?
与有限变换相比,连续时间流有几个优势:
- 灵活性:通过调整 f f f 的形式(如神经网络),可以建模复杂的动态轨迹。
- 理论深度:连接了物理学中的流体力学和统计力学(如Liouville方程)。
- 平滑性:避免离散步骤的突变,适合需要平滑变换的应用。
然而,计算 z ( T ) \mathbf{z}(T) z(T) 和散度积分需要数值方法,这引入了额外的复杂性。
4.2 Solving and Optimizing Continuous-Time Flows(简述)
对于熟悉神经ODE(Neural ODE)的读者,4.2小节的内容并不陌生。求解连续时间流依赖于数值积分器(如Runge-Kutta,可以参考笔者的另外的博客:RK-4(四阶 Runge-Kutta 方法):更精确的 ODE 求解利器),计算 z ( T ) \mathbf{z}(T) z(T) 和散度积分可以通过伴随方法(adjoint method)优化,复杂度为 O ( D ) \mathcal{O}(D) O(D) 加上积分步数,可以参考笔者的博客: “伴随敏感性方法”(Adjoint Sensitivity Method):“连续时间的反向传播”。优化则通常通过梯度下降调整 ϕ \phi ϕ,结合变分推断或最大似然估计。细节上,FFJORD(Grathwohl et al., 2019)等方法通过 Hutchinson 估计器加速散度计算,读者可参考原文或相关文献。
总结
连续时间变换通过ODE将normalizing flows从离散步骤扩展到动态演化,4.1节的定义奠定了其数学基础。核心公式 d z d t = f \frac{d\mathbf{z}}{dt} = f dtdz=f 和 log p x ( x ) = log p u ( u ) − ∫ ∇ ⋅ f d t \log p_{\mathbf{x}}(\mathbf{x}) = \log p_{\mathbf{u}}(\mathbf{u}) - \int \nabla \cdot f \, dt logpx(x)=logpu(u)−∫∇⋅fdt 揭示了概率密度如何随向量场流动。对于深度学习研究者,这不仅是一个建模工具,更是一个连接微分方程与概率推断的桥梁。想尝试实现?不妨从简单的向量场开始,探索连续流的无限可能!
后记
2025年4月3日21点14分于上海,在grok 3大模型辅助下完成。