深入解析 Flow Matching(二):从条件概率路径与向量场到条件流匹配
在生成模型的研究中,Flow Matching(流匹配)是一种高效且优雅的训练方法,基于连续归一化流(Continuous Normalizing Flows, CNFs)。论文《Flow Matching for Generative Modeling》第 4 节“Conditional Probability Paths and Vector Fields”(条件概率路径与向量场)是理解这一方法的关键部分,它详细描述了如何通过条件概率路径和向量场构造生成过程。本文将以通俗的语言介绍这一节的内容,包括 4.1 节“Special Instances of Gaussian Conditional Probability Paths”(高斯条件概率路径的特殊实例),并给出 Theorem 3 的证明,帮助你快速把握来龙去脉。
请先参考笔者的第一篇博客:深入解析 Flow Matching:从条件概率路径与向量场到条件流匹配
paper link: https://arxiv.org/pdf/2210.02747
4 Conditional Probability Paths and Vector Fields 是什么?
想象一下,你想从一团随机噪声(比如标准正态分布)出发,逐步“雕刻”出真实数据的形状(比如图片或文本的分布)。Flow Matching 的核心就是通过一个平滑的“流动”过程实现这一点。第 4 节聚焦于如何为每个目标样本 ( x 1 x_1 x1 )(来自数据分布 ( q ( x 1 ) q(x_1) q(x1) ))设计一个专属的路径,让噪声分布逐步变成目标分布。
基本概念
-
条件概率 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(x∣x1) ):这是一个概率密度路径,描述在给定目标 ( x 1 x_1 x1 ) 的情况下,( x x x ) 在时间 ( t t t ) 的分布。
- ( t = 0 t = 0 t=0 ):起点是噪声,比如 ( p 0 ( x ∣ x 1 ) = N ( 0 , I ) p_0(x | x_1) = \mathcal{N}(0, I) p0(x∣x1)=N(0,I) )(标准正态分布)。
- ( t = 1 t = 1 t=1 ):终点接近 ( x 1 x_1 x1 ),比如 ( p 1 ( x ∣ x 1 ) p_1(x | x_1) p1(x∣x1) ) 几乎集中在 ( x 1 x_1 x1 ) 附近。
-
流 ( ψ t ( x ) \psi_t(x) ψt(x) ):一个变换函数,像流水线一样,把初始的 ( x x x)(从噪声中采样)推到时间 ( t t t ) 的位置。它通常是一个仿射变换,比如:
ψ t ( x ) = σ t ( x 1 ) x + μ t ( x 1 ) \psi_t(x) = \sigma_t(x_1) x + \mu_t(x_1) ψt(x)=σt(x1)x+μt(x1)
其中 ( σ t ( x 1 ) \sigma_t(x_1) σt(x1) ) 控制缩放,( μ t ( x 1 ) \mu_t(x_1) μt(x1) ) 控制平移。 -
向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ):流的“动力”,告诉 ( ψ t ( x ) \psi_t(x) ψt(x) ) 每时每刻该往哪儿走,满足:
d d t ψ t ( x ) = u t ( ψ t ( x ) ∣ x 1 ) \frac{d}{dt} \psi_t(x) = u_t(\psi_t(x) | x_1) dtdψt(x)=ut(ψt(x)∣x1)
为什么需要这些?
Flow Matching 的目标是训练一个神经网络去逼近 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ),但直接定义全局的概率路径 ( p t ( x ) p_t(x) pt(x) ) 和向量场 ( u t ( x ) u_t(x) ut(x) ) 很难(因为数据分布 ( q(x) ) 未知)。通过引入条件路径 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(x∣x1) ),我们可以先为每个 ( x 1 x_1 x1 ) 设计一个简单路径,再通过平均得到全局效果。这种“分而治之”的思路是第 3 节条件流匹配的延续。
Theorem 3:流与向量场的数学保证
Theorem 3 是这一节的理论基石,它证明了流 ( ψ t ( x ) \psi_t(x) ψt(x) ) 和向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ) 的存在性和一致性。具体表述为:存在 ( ψ t ( x ) \psi_t(x) ψt(x) ) 和 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ),使得:
- ( [ ψ t ] ∗ p ( x ) = p t ( x ∣ x 1 ) [\psi_t]_* p(x) = p_t(x | x_1) [ψt]∗p(x)=pt(x∣x1) )(流将初始分布推到条件分布, 这里的符号请参考前文提到的博客文章,那里是前置知识)。
- ( d d t ψ t ( x ) = u t ( ψ t ( x ) ∣ x 1 ) \frac{d}{dt} \psi_t(x) = u_t(\psi_t(x) | x_1) dtdψt(x)=ut(ψt(x)∣x1) )(向量场驱动流, 公式(13))。
证明
假设 ( p ( x ) = N ( 0 , I ) p(x) = \mathcal{N}(0, I) p(x)=N(0,I) ),( p t ( x ∣ x 1 ) = N ( x ∣ μ t ( x 1 ) , σ t ( x 1 ) 2 I ) p_t(x | x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I) pt(x∣x1)=N(x∣μt(x1),σt(x1)2I) ):
-
构造流:
ψ t ( x ) = σ t ( x 1 ) x + μ t ( x 1 ) \psi_t(x) = \sigma_t(x_1) x + \mu_t(x_1) ψt(x)=σt(x1)x+μt(x1)- ( t = 0 t = 0 t=0 ):( σ 0 ( x 1 ) = 1 \sigma_0(x_1) = 1 σ0(x1)=1 ),( μ 0 ( x 1 ) = 0 \mu_0(x_1) = 0 μ0(x1)=0 ),( ψ 0 ( x ) = x \psi_0(x) = x ψ0(x)=x ),匹配 ( p 0 ( x ∣ x 1 ) = N ( 0 , I ) p_0(x | x_1) = \mathcal{N}(0, I) p0(x∣x1)=N(0,I) )。
- ( t = 1 t = 1 t=1 ):( σ 1 ( x 1 ) ≈ 0 \sigma_1(x_1) \approx 0 σ1(x1)≈0 ),( μ 1 ( x 1 ) = x 1 \mu_1(x_1) = x_1 μ1(x1)=x1 ),接近 ( x 1 x_1 x1 )。
验证推前:
- ( x ∼ N ( 0 , I ) x \sim \mathcal{N}(0, I) x∼N(0,I) ),则 ( ψ t ( x ) ∼ N ( μ t ( x 1 ) , σ t ( x 1 ) 2 I \psi_t(x) \sim \mathcal{N}(\mu_t(x_1), \sigma_t(x_1)^2 I ψt(x)∼N(μt(x1),σt(x1)2I) )。
- 计算 (
[
ψ
t
]
∗
p
(
x
)
=
p
(
ψ
t
−
1
(
x
)
)
det
(
∂
ψ
t
−
1
∂
x
)
[\psi_t]_* p(x) = p(\psi_t^{-1}(x)) \det\left(\frac{\partial \psi_t^{-1}}{\partial x}\right)
[ψt]∗p(x)=p(ψt−1(x))det(∂x∂ψt−1) ):
ψ t − 1 ( x ) = x − μ t ( x 1 ) σ t ( x 1 ) , det ( ∂ ψ t − 1 ∂ x ) = 1 σ t ( x 1 ) d \psi_t^{-1}(x) = \frac{x - \mu_t(x_1)}{\sigma_t(x_1)}, \quad \det\left(\frac{\partial \psi_t^{-1}}{\partial x}\right) = \frac{1}{\sigma_t(x_1)^d} ψt−1(x)=σt(x1)x−μt(x1),det(∂x∂ψt−1)=σt(x1)d1
[ ψ t ] ∗ p ( x ) = N ( x ∣ μ t ( x 1 ) , σ t ( x 1 ) 2 I ) = p t ( x ∣ x 1 ) [\psi_t]_* p(x) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I) = p_t(x | x_1) [ψt]∗p(x)=N(x∣μt(x1),σt(x1)2I)=pt(x∣x1)
-
计算向量场:
d d t ψ t ( x ) = σ ˙ t ( x 1 ) x + μ ˙ t ( x 1 ) \frac{d}{dt} \psi_t(x) = \dot{\sigma}_t(x_1) x + \dot{\mu}_t(x_1) dtdψt(x)=σ˙t(x1)x+μ˙t(x1)
设 ( z = ψ t ( x ) z = \psi_t(x) z=ψt(x) ):
u t ( z ∣ x 1 ) = σ ˙ t ( x 1 ) ⋅ z − μ t ( x 1 ) σ t ( x 1 ) + μ ˙ t ( x 1 ) u_t(z | x_1) = \dot{\sigma}_t(x_1) \cdot \frac{z - \mu_t(x_1)}{\sigma_t(x_1)} + \dot{\mu}_t(x_1) ut(z∣x1)=σ˙t(x1)⋅σt(x1)z−μt(x1)+μ˙t(x1)
代入 ( z = ψ t ( x ) z = \psi_t(x) z=ψt(x) ),满足公式 (13)。 -
验证一致性:
连续性方程 ( d d t p t ( x ∣ x 1 ) + div ( u t ( x ∣ x 1 ) p t ( x ∣ x 1 ) ) = 0 \frac{d}{dt} p_t(x | x_1) + \operatorname{div}(u_t(x | x_1) p_t(x | x_1)) = 0 dtdpt(x∣x1)+div(ut(x∣x1)pt(x∣x1))=0 ) 保证 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ) 生成 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(x∣x1) )。具体计算较复杂(涉及高斯导数和散度),但可通过 4.1 节实例验证。
结论:( ψ t ( x ) \psi_t(x) ψt(x) ) 和 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ) 满足要求。
Flow Matching 进阶:高斯条件概率路径的特殊实例
在《Flow Matching for Generative Modeling》第 4 节“Conditional Probability Paths and Vector Fields”中,我们了解了条件概率路径 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(x∣x1) ) 和向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ) 的基本概念。4.1 节“Special Instances of Gaussian Conditional Probability Paths”(高斯条件概率路径的特殊实例)则进一步展示了这一框架的灵活性,通过调整 ( μ t ( x 1 ) \mu_t(x_1) μt(x1) )(均值)和 ( σ t ( x 1 ) \sigma_t(x_1) σt(x1) )(标准差),我们可以设计出不同的生成路径。本文将深入探讨这一节,介绍两种典型实例——扩散条件路径和最优传输路径,带你理解它们的来龙去脉和应用场景。
4.1 高斯条件概率路径的特殊实例:灵活设计的艺术
Flow Matching 的条件概率路径 ( p t ( x ∣ x 1 ) = N ( x ∣ μ t ( x 1 ) , σ t ( x 1 ) 2 I ) p_t(x | x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I) pt(x∣x1)=N(x∣μt(x1),σt(x1)2I) ) 是一个通用框架,( μ t ( x 1 ) \mu_t(x_1) μt(x1) ) 和 ( σ t ( x 1 ) \sigma_t(x_1) σt(x1) ) 可以是任意满足边界条件的可微函数:
- ( t = 0 t = 0 t=0 ):起点通常是噪声分布(如 ( N ( 0 , I ) \mathcal{N}(0, I) N(0,I) ) 或接近噪声)。
- ( t = 1 t = 1 t=1 ):终点接近目标样本 ( x 1 x_1 x1 )(如 ( N ( x 1 , σ min 2 I ) \mathcal{N}(x_1, \sigma_{\text{min}}^2 I) N(x1,σmin2I) ))。
这种灵活性让 Flow Matching 既能复现传统扩散过程,又能探索全新的路径设计。4.1 节提供了两个例子:一个连接扩散模型,一个基于最优传输(Optimal Transport, OT),展示了这种设计的多样性。
示例 I:扩散条件向量场
背景与定义
扩散模型(Diffusion Models)通过从数据加噪到纯噪声(前向过程),再从噪声去噪到数据(反向过程),实现生成。Flow Matching 可以复现这种扩散路径的反向过程,定义特定的 ( μ t ( x 1 ) \mu_t(x_1) μt(x1) ) 和 ( σ t ( x 1 ) \sigma_t(x_1) σt(x1) )。
1. Variance Exploding (VE) 路径
- 路径:
p t ( x ∣ x 1 ) = N ( x ∣ x 1 , σ 1 − t 2 I ) p_t(x | x_1) = \mathcal{N}(x | x_1, \sigma_{1-t}^2 I) pt(x∣x1)=N(x∣x1,σ1−t2I)- ( σ t \sigma_t σt ) 是递增函数,( σ 0 = 0 \sigma_0 = 0 σ0=0 ),( σ 1 ≫ 1 \sigma_1 \gg 1 σ1≫1 )。
- ( t = 0 t = 0 t=0 ):( p 0 ( x ∣ x 1 ) = N ( x 1 , σ 1 2 I ) p_0(x | x_1) = \mathcal{N}(x_1, \sigma_1^2 I) p0(x∣x1)=N(x1,σ12I) )(高方差噪声)。
- ( t = 1 t = 1 t=1 ):( p 1 ( x ∣ x 1 ) = N ( x 1 , 0 ) p_1(x | x_1) = \mathcal{N}(x_1, 0) p1(x∣x1)=N(x1,0) )(集中在 ( x 1 x_1 x1 ))。
- 设置:
μ t ( x 1 ) = x 1 , σ t ( x 1 ) = σ 1 − t \mu_t(x_1) = x_1, \quad \sigma_t(x_1) = \sigma_{1-t} μt(x1)=x1,σt(x1)=σ1−t - 向量场(由 Theorem 3 的公式 15 (
u
t
(
x
∣
x
1
)
=
σ
˙
t
(
x
1
)
⋅
x
−
μ
t
(
x
1
)
σ
t
(
x
1
)
+
μ
˙
t
(
x
1
)
u_t(x | x_1) = \dot{\sigma}_t(x_1) \cdot \frac{x - \mu_t(x_1)}{\sigma_t(x_1)} + \dot{\mu}_t(x_1)
ut(x∣x1)=σ˙t(x1)⋅σt(x1)x−μt(x1)+μ˙t(x1) ) 导出):
σ ˙ t ( x 1 ) = d d t σ 1 − t = − σ 1 − t ′ , μ ˙ t ( x 1 ) = 0 \dot{\sigma}_t(x_1) = \frac{d}{dt} \sigma_{1-t} = -\sigma'_{1-t}, \quad \dot{\mu}_t(x_1) = 0 σ˙t(x1)=dtdσ1−t=−σ1−t′,μ˙t(x1)=0
u t ( x ∣ x 1 ) = − σ 1 − t ′ σ 1 − t ( x − x 1 ) u_t(x | x_1) = -\frac{\sigma'_{1-t}}{\sigma_{1-t}} (x - x_1) ut(x∣x1)=−σ1−tσ1−t′(x−x1)
2. Variance Preserving (VP) 路径
- 路径:
p t ( x ∣ x 1 ) = N ( x ∣ α 1 − t x 1 , ( 1 − α 1 − t 2 ) I ) p_t(x | x_1) = \mathcal{N}(x | \alpha_{1-t} x_1, (1 - \alpha_{1-t}^2) I) pt(x∣x1)=N(x∣α1−tx1,(1−α1−t2)I)- ( α t = e − 1 2 T ( t ) \alpha_t = e^{-\frac{1}{2} T(t)} αt=e−21T(t) ),( T ( t ) = ∫ 0 t β ( s ) d s T(t) = \int_0^t \beta(s) \, ds T(t)=∫0tβ(s)ds ),( β ( s ) \beta(s) β(s) ) 是噪声调度函数。
- ( t = 0 t = 0 t=0 ):( p 0 ( x ∣ x 1 ) = N ( x 1 , 1 ) p_0(x | x_1) = \mathcal{N}(x_1, 1) p0(x∣x1)=N(x1,1) )(接近噪声)。
- ( t = 1 t = 1 t=1 ):( p 1 ( x ∣ x 1 ) = N ( x 1 , 0 ) p_1(x | x_1) = \mathcal{N}(x_1, 0) p1(x∣x1)=N(x1,0) )(集中在 ( x 1 x_1 x1 ))。
- 设置:
μ t ( x 1 ) = α 1 − t x 1 , σ t ( x 1 ) = 1 − α 1 − t 2 \mu_t(x_1) = \alpha_{1-t} x_1, \quad \sigma_t(x_1) = \sqrt{1 - \alpha_{1-t}^2} μt(x1)=α1−tx1,σt(x1)=1−α1−t2 - 向量场:
μ ˙ t ( x 1 ) = − α 1 − t ′ x 1 , σ ˙ t ( x 1 ) = − α 1 − t α 1 − t ′ 1 − α 1 − t 2 \dot{\mu}_t(x_1) = -\alpha'_{1-t} x_1, \quad \dot{\sigma}_t(x_1) = \frac{-\alpha_{1-t} \alpha'_{1-t}}{\sqrt{1 - \alpha_{1-t}^2}} μ˙t(x1)=−α1−t′x1,σ˙t(x1)=1−α1−t2−α1−tα1−t′
u t ( x ∣ x 1 ) = − α 1 − t ′ 1 − α 1 − t 2 ( x − α 1 − t x 1 ) − α 1 − t ′ x 1 u_t(x | x_1) = -\frac{\alpha'_{1-t}}{\sqrt{1 - \alpha_{1-t}^2}} (x - \alpha_{1-t} x_1) - \alpha'_{1-t} x_1 ut(x∣x1)=−1−α1−t2α1−t′(x−α1−tx1)−α1−t′x1
化简后:
u t ( x ∣ x 1 ) = − T ′ ( 1 − t ) 2 ⋅ e − T ( 1 − t ) x − e − 1 2 T ( 1 − t ) x 1 1 − e − T ( 1 − t ) u_t(x | x_1) = -\frac{T'(1 - t)}{2} \cdot \frac{e^{-T(1-t)} x - e^{-\frac{1}{2} T(1-t)} x_1}{1 - e^{-T(1-t)}} ut(x∣x1)=−2T′(1−t)⋅1−e−T(1−t)e−T(1−t)x−e−21T(1−t)x1
联系与特点
- 与扩散模型的联系:这些路径与 Song 等人的概率流(Probability Flow)一致(见 Song et al., 2020b),但 Flow Matching 通过直接回归向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ),避免了扩散模型中复杂的得分匹配(Score Matching),实验中更稳定。
- 局限:扩散路径受限于其起源(扩散过程的解),( t = 0 t = 0 t=0 ) 时并非纯噪声(如 ( N ( 0 , I ) \mathcal{N}(0, I) N(0,I) )),而是近似,需要额外调整。
示例 II:最优传输条件向量场
背景与定义
既然 Flow Matching 不依赖扩散过程,我们可以直接设计更自然的路径,比如基于 Wasserstein-2 最优传输(Optimal Transport, OT)的路径,让均值和方差线性变化。
- 路径:
μ t ( x 1 ) = t x 1 , σ t ( x 1 ) = 1 − ( 1 − σ min ) t \mu_t(x_1) = t x_1, \quad \sigma_t(x_1) = 1 - (1 - \sigma_{\text{min}}) t μt(x1)=tx1,σt(x1)=1−(1−σmin)t- ( t = 0 t = 0 t=0 ):( p 0 ( x ∣ x 1 ) = N ( 0 , 1 ) p_0(x | x_1) = \mathcal{N}(0, 1) p0(x∣x1)=N(0,1) )(标准正态)。
- ( t = 1 t = 1 t=1 ):( p 1 ( x ∣ x 1 ) = N ( x 1 , σ min 2 ) p_1(x | x_1) = \mathcal{N}(x_1, \sigma_{\text{min}}^2) p1(x∣x1)=N(x1,σmin2) )(接近 ( x 1 x_1 x1 ))。
- 流:
ψ t ( x ) = ( 1 − ( 1 − σ min ) t ) x + t x 1 \psi_t(x) = (1 - (1 - \sigma_{\text{min}}) t) x + t x_1 ψt(x)=(1−(1−σmin)t)x+tx1 - 向量场:
μ ˙ t ( x 1 ) = x 1 , σ ˙ t ( x 1 ) = − ( 1 − σ min ) \dot{\mu}_t(x_1) = x_1, \quad \dot{\sigma}_t(x_1) = -(1 - \sigma_{\text{min}}) μ˙t(x1)=x1,σ˙t(x1)=−(1−σmin)
u t ( x ∣ x 1 ) = − ( 1 − σ min ) ⋅ x − t x 1 1 − ( 1 − σ min ) t + x 1 = x 1 − ( 1 − σ min ) x 1 − ( 1 − σ min ) t u_t(x | x_1) = -(1 - \sigma_{\text{min}}) \cdot \frac{x - t x_1}{1 - (1 - \sigma_{\text{min}}) t} + x_1 = \frac{x_1 - (1 - \sigma_{\text{min}}) x}{1 - (1 - \sigma_{\text{min}}) t} ut(x∣x1)=−(1−σmin)⋅1−(1−σmin)tx−tx1+x1=1−(1−σmin)tx1−(1−σmin)x - 损失:
L CFM ( θ ) = E t , q ( x 1 ) , p ( x 0 ) ∥ v t ( ψ t ( x 0 ) ) − x 1 − ( 1 − σ min ) x 0 1 − ( 1 − σ min ) t ∥ 2 \mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t, q(x_1), p(x_0)} \left\| v_t(\psi_t(x_0)) - \frac{x_1 - (1 - \sigma_{\text{min}}) x_0}{1 - (1 - \sigma_{\text{min}}) t} \right\|^2 LCFM(θ)=Et,q(x1),p(x0) vt(ψt(x0))−1−(1−σmin)tx1−(1−σmin)x0 2
最优传输的意义
- OT 特性:(
ψ
t
(
x
)
\psi_t(x)
ψt(x) ) 是从 (
p
0
(
x
∣
x
1
)
=
N
(
0
,
1
)
p_0(x | x_1) = \mathcal{N}(0, 1)
p0(x∣x1)=N(0,1) ) 到 (
p
1
(
x
∣
x
1
)
=
N
(
x
1
,
σ
min
2
)
p_1(x | x_1) = \mathcal{N}(x_1, \sigma_{\text{min}}^2)
p1(x∣x1)=N(x1,σmin2) ) 的 OT 位移映射(Displacement Map),路径为:
p t = [ ( 1 − t ) id + t ψ ] ∗ p 0 p_t = [(1 - t) \text{id} + t \psi]_* p_0 pt=[(1−t)id+tψ]∗p0
粒子沿直线移动,速度恒定(见 McCann, 1997)。 - 优势:
- 相比扩散路径(可能“超调”后回溯),OT 路径始终直线前进,避免浪费。
- 定义在整个 ( t ∈ [ 0 , 1 ] t \in [0, 1] t∈[0,1] ),控制更直接。
总结
4.1 节展示了 Flow Matching 的强大灵活性:
- 扩散条件路径复现了传统扩散模型的反向过程(如 VE 和 VP),通过解析向量场提升训练效率,但受限于扩散起源。
- 最优传输路径跳出扩散框架,利用 OT 理论设计直线路径,简单高效且直观。
这些实例不仅连接了已有方法,还开辟了新思路。无论是借鑒扩散模型的经验,还是探索 OT 的优雅,Flow Matching 都为生成模型提供了丰富的可能性。希望这篇博客让你对高斯条件概率路径的设计有更直观的认识!如果想深入某个实例,欢迎继续讨论。
Flow Matching 进阶:高斯条件概率路径的特殊实例(补充版)
在《Flow Matching for Generative Modeling》第 4.1 节“Special Instances of Gaussian Conditional Probability Paths”(高斯条件概率路径的特殊实例)中,我们探讨了如何通过调整 ( μ t ( x 1 ) \mu_t(x_1) μt(x1) ) 和 ( σ t ( x 1 ) \sigma_t(x_1) σt(x1) ) 设计不同的生成路径。本文在之前的博客基础上,补充了对 Variance Exploding (VE) 路径和 Variance Preserving (VP) 路径的详细解释,以及对最优传输(OT)路径中 OT 位移映射公式的拆解,帮助你更全面地理解这些概念的来龙去脉。
示例 I:扩散条件向量场
什么是 Variance Exploding (VE) 和 Variance Preserving (VP) 路径?
扩散模型(Diffusion Models)通过前向过程加噪、反向过程去噪生成数据,其概率路径通常是高斯形式。VE 和 VP 是两种常见的路径设计,分别对应不同的噪声スケール策略。
Variance Exploding (VE) 路径
- 定义:
p t ( x ∣ x 1 ) = N ( x ∣ x 1 , σ 1 − t 2 I ) p_t(x | x_1) = \mathcal{N}(x | x_1, \sigma_{1-t}^2 I) pt(x∣x1)=N(x∣x1,σ1−t2I)- ( μ t ( x 1 ) = x 1 \mu_t(x_1) = x_1 μt(x1)=x1 ),均值固定。
- ( σ t ( x 1 ) = σ 1 − t \sigma_t(x_1) = \sigma_{1-t} σt(x1)=σ1−t ),( σ t \sigma_t σt ) 是递增函数,( σ 0 = 0 \sigma_0 = 0 σ0=0 ),( σ 1 ≫ 1 \sigma_1 \gg 1 σ1≫1 )(如 1000)。
- ( t = 0 t = 0 t=0 ):( p 0 ( x ∣ x 1 ) = N ( x 1 , σ 1 2 I ) p_0(x | x_1) = \mathcal{N}(x_1, \sigma_1^2 I) p0(x∣x1)=N(x1,σ12I) )(高方差噪声)。
- ( t = 1 t = 1 t=1 ):( p 1 ( x ∣ x 1 ) = N ( x 1 , 0 ) p_1(x | x_1) = \mathcal{N}(x_1, 0) p1(x∣x1)=N(x1,0) )(集中在 ( x 1 x_1 x1 ))。
- 向量场:
u t ( x ∣ x 1 ) = − σ 1 − t ′ σ 1 − t ( x − x 1 ) u_t(x | x_1) = -\frac{\sigma'_{1-t}}{\sigma_{1-t}} (x - x_1) ut(x∣x1)=−σ1−tσ1−t′(x−x1) - 特点:方差“爆炸式”增长,前向过程从数据加噪到接近纯噪声(( σ 1 2 \sigma_1^2 σ12 ) 很大),反向过程则逐渐减小方差。
Variance Preserving (VP) 路径
- 定义:
p t ( x ∣ x 1 ) = N ( x ∣ α 1 − t x 1 , ( 1 − α 1 − t 2 ) I ) p_t(x | x_1) = \mathcal{N}(x | \alpha_{1-t} x_1, (1 - \alpha_{1-t}^2) I) pt(x∣x1)=N(x∣α1−tx1,(1−α1−t2)I)- ( α t = e − 1 2 T ( t ) \alpha_t = e^{-\frac{1}{2} T(t)} αt=e−21T(t) ),( T ( t ) = ∫ 0 t β ( s ) d s T(t) = \int_0^t \beta(s) \, ds T(t)=∫0tβ(s)ds ),( β ( s ) \beta(s) β(s)) 是噪声调度函数。
- ( μ t ( x 1 ) = α 1 − t x 1 \mu_t(x_1) = \alpha_{1-t} x_1 μt(x1)=α1−tx1 ),均值随时间缩减。
- ( σ t ( x 1 ) = 1 − α 1 − t 2 \sigma_t(x_1) = \sqrt{1 - \alpha_{1-t}^2} σt(x1)=1−α1−t2 )。
- ( t = 0 t = 0 t=0 ):( p 0 ( x ∣ x 1 ) = N ( x 1 , 1 ) p_0(x | x_1) = \mathcal{N}(x_1, 1) p0(x∣x1)=N(x1,1) )(接近噪声)。
- ( t = 1 t = 1 t=1 ):( p 1 ( x ∣ x 1 ) = N ( x 1 , 0 ) p_1(x | x_1) = \mathcal{N}(x_1, 0) p1(x∣x1)=N(x1,0) )(集中在 ( x 1 x_1 x1 ))。
- 向量场:
u t ( x ∣ x 1 ) = − T ′ ( 1 − t ) 2 ⋅ e − T ( 1 − t ) x − e − 1 2 T ( 1 − t ) x 1 1 − e − T ( 1 − t ) u_t(x | x_1) = -\frac{T'(1 - t)}{2} \cdot \frac{e^{-T(1-t)} x - e^{-\frac{1}{2} T(1-t)} x_1}{1 - e^{-T(1-t)}} ut(x∣x1)=−2T′(1−t)⋅1−e−T(1−t)e−T(1−t)x−e−21T(1−t)x1 - 特点:方差保持在一定范围内(从 1 到 0),均值逐渐从 ( x 1 x_1 x1 ) 缩到 0(前向),反向则恢复。
VE 和 VP 的不同
- 方差行为:
- VE:方差从 0 增加到非常大(如 ( σ 1 2 ≫ 1 \sigma_1^2 \gg 1 σ12≫1 )),强调“爆炸”效应。
- VP:方差从 0 到 1 再回到 0,保持相对稳定,避免过大噪声。
- 均值行为:
- VE:均值始终是 ( x 1 x_1 x1 ),不变。
- VP:均值从 ( x 1 x_1 x1 ) 缩到接近 0(前向),反向恢复到 ( x 1 x_1 x1 )。
- 适用场景:
- VE:适合需要强噪声的任务(如高维数据),但计算复杂度高。
- VP:更平衡,常用在图像生成(如 DDPM),稳定性更好。
为什么提出这两种路径?
- VE 的起源:Sohl-Dickstein 等(2015)提出扩散模型时,VE 路径通过大方差模拟纯噪声,便于理论分析和采样。
- VP 的改进:Ho 等(2020)发现 VE 的极端方差会导致训练不稳定,提出 VP 路径,通过控制 ( \beta(t) ) 保持方差适中,提升生成质量。Song 等(2020b)进一步将 VP 路径用于概率流,奠定了其主流地位。
- 动机:两者都旨在通过高斯路径简化扩散过程的闭式解,但 VE 更理论化,VP 更实用化。
示例 II:最优传输条件向量场
定义与路径
- 设置:
μ t ( x 1 ) = t x 1 , σ t ( x 1 ) = 1 − ( 1 − σ min ) t \mu_t(x_1) = t x_1, \quad \sigma_t(x_1) = 1 - (1 - \sigma_{\text{min}}) t μt(x1)=tx1,σt(x1)=1−(1−σmin)t- ( t = 0 t = 0 t=0 ):( p 0 ( x ∣ x 1 ) = N ( 0 , 1 ) p_0(x | x_1) = \mathcal{N}(0, 1) p0(x∣x1)=N(0,1) )。
- ( t = 1 t = 1 t=1 ):( p 1 ( x ∣ x 1 ) = N ( x 1 , σ min 2 ) p_1(x | x_1) = \mathcal{N}(x_1, \sigma_{\text{min}}^2) p1(x∣x1)=N(x1,σmin2) )(( σ min \sigma_{\text{min}} σmin ) 通常很小)。
- 流:
ψ t ( x ) = ( 1 − ( 1 − σ min ) t ) x + t x 1 \psi_t(x) = (1 - (1 - \sigma_{\text{min}}) t) x + t x_1 ψt(x)=(1−(1−σmin)t)x+tx1 - 向量场:
u t ( x ∣ x 1 ) = x 1 − ( 1 − σ min ) x 1 − ( 1 − σ min ) t u_t(x | x_1) = \frac{x_1 - (1 - \sigma_{\text{min}}) x}{1 - (1 - \sigma_{\text{min}}) t} ut(x∣x1)=1−(1−σmin)tx1−(1−σmin)x
OT 特性详解
- OT 特性:
p t = [ ( 1 − t ) id + t ψ ] ∗ p 0 p_t = [(1 - t) \text{id} + t \psi]_* p_0 pt=[(1−t)id+tψ]∗p0- ( ψ t ( x ) \psi_t(x) ψt(x) ) 是从 ( p 0 ( x ∣ x 1 ) = N ( 0 , 1 ) p_0(x | x_1) = \mathcal{N}(0, 1) p0(x∣x1)=N(0,1) ) 到 ( p 1 ( x ∣ x 1 ) = N ( x 1 , σ min 2 ) p_1(x | x_1) = \mathcal{N}(x_1, \sigma_{\text{min}}^2) p1(x∣x1)=N(x1,σmin2) ) 的 OT 位移映射(Displacement Map)。
- 粒子沿直线移动,速度恒定(见 McCann, 1997)。
公式拆解
-
( id \text{id} id ) 是什么?
- ( id \text{id} id ) 是恒等映射(identity map),即 ( id ( x ) = x \text{id}(x) = x id(x)=x )。它表示不做任何变换,直接输出输入本身。
- 在这里,( id \text{id} id ) 代表路径的起点 ( p 0 p_0 p0 ) 的位置。
-
( ( 1 − t ) id + t ψ (1 - t) \text{id} + t \psi (1−t)id+tψ ) 是什么?
- 这是一个线性插值函数,称为 OT 位移映射:
( 1 − t ) id ( x ) + t ψ ( x ) = ( 1 − t ) x + t ψ ( x ) (1 - t) \text{id}(x) + t \psi(x) = (1 - t) x + t \psi(x) (1−t)id(x)+tψ(x)=(1−t)x+tψ(x) - 代入 (
ψ
(
x
)
=
x
1
\psi(x) = x_1
ψ(x)=x1 )(假设 (
σ
min
=
0
\sigma_{\text{min}} = 0
σmin=0 ) 简化理解):
( 1 − t ) x + t x 1 (1 - t) x + t x_1 (1−t)x+tx1 - ( t = 0 t = 0 t=0 ):输出 ( x x x )(起点)。
- ( t = 1 t = 1 t=1 ):输出 ( x 1 x_1 x1 )(终点)。
- 随着 ( t t t ) 从 0 到 1,粒子从 ( x x x ) 直线移动到 ( x 1 x_1 x1 )。
- 这是一个线性插值函数,称为 OT 位移映射:
-
( [ ( 1 − t ) id + t ψ ] ∗ p 0 [(1 - t) \text{id} + t \psi]_* p_0 [(1−t)id+tψ]∗p0 ) 是什么?
- ( [ ⋅ ] ∗ [\cdot]_* [⋅]∗ ) 表示推前(push-forward)操作,将分布 ( p 0 p_0 p0 ) 通过映射变换到新分布。
- ( p 0 ( x ∣ x 1 ) = N ( 0 , 1 ) p_0(x | x_1) = \mathcal{N}(0, 1) p0(x∣x1)=N(0,1) ) 是初始分布。
- 令 (
ϕ
t
(
x
)
=
(
1
−
t
)
x
+
t
ψ
(
x
)
=
(
1
−
(
1
−
σ
min
)
t
)
x
+
t
x
1
\phi_t(x) = (1 - t) x + t \psi(x) = (1 - (1 - \sigma_{\text{min}}) t) x + t x_1
ϕt(x)=(1−t)x+tψ(x)=(1−(1−σmin)t)x+tx1 ):
- ( x ∼ N ( 0 , 1 ) x \sim \mathcal{N}(0, 1) x∼N(0,1) )。
- ( ϕ t ( x ) ∼ N ( t x 1 , ( 1 − ( 1 − σ min ) t ) 2 I ) \phi_t(x) \sim \mathcal{N}(t x_1, (1 - (1 - \sigma_{\text{min}}) t)^2 I) ϕt(x)∼N(tx1,(1−(1−σmin)t)2I) )(均值和方差通过线性变换计算)。
- 因此:
p t ( x ) = [ ( 1 − t ) id + t ψ ] ∗ p 0 = N ( x ∣ t x 1 , ( 1 − ( 1 − σ min ) t ) 2 I ) p_t(x) = [(1 - t) \text{id} + t \psi]_* p_0 = \mathcal{N}(x | t x_1, (1 - (1 - \sigma_{\text{min}}) t)^2 I) pt(x)=[(1−t)id+tψ]∗p0=N(x∣tx1,(1−(1−σmin)t)2I) - 这与定义一致,验证了 OT 路径的正确性。
直观理解
- 直线移动:粒子从 ( x ∼ N ( 0 , 1 ) x \sim \mathcal{N}(0, 1) x∼N(0,1) ) 到 ( x 1 x_1 x1 ) 的路径是直线,速度 ( x 1 − x 1 \frac{x_1 - x}{1} 1x1−x ) 恒定。
- OT 的优雅:OT 路径是最优的(Wasserstein-2 距离最小),避免了扩散路径可能出现的“超调”或迂回。
总结
4.1 节通过 VE 和 VP 路径展示了 Flow Matching 与扩散模型的联系:
- VE:方差爆炸,适合强噪声场景。
- VP:方差平稳,更实用且稳定。
而 OT 路径则跳出扩散框架:
- 通过 ( p t = [ ( 1 − t ) id + t ψ ] ∗ p 0 p_t = [(1 - t) \text{id} + t \psi]_* p_0 pt=[(1−t)id+tψ]∗p0 ),利用 OT 位移映射实现直线高效移动,简单且优雅。
这些设计不仅复现了传统方法,还拓展了新可能。希望这篇补充博客让你对 VE、VP 和 OT 路径的细节一目了然!
Flow Matching 的训练和推理代码示例
下面将提供一个基于 Flow Matching 的训练和推理代码示例,使用最优传输(Optimal Transport, OT)路径。我们将聚焦于如何计算流 ( ψ t ( x ) \psi_t(x) ψt(x) ) 和向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ),并展示如何为每个目标样本 ( x 1 x_1 x1 ) 设计专属路径,让噪声分布 ( N ( 0 , I ) \mathcal{N}(0, I) N(0,I) ) 逐步变成目标分布 ( q ( x 1 ) q(x_1) q(x1) )。由于直接使用 ImageNet 需要大量预处理和计算资源,我将以 MNIST 数据集为例(一个简单的图片数据集),并在代码中注释说明如何扩展到 ImageNet。
代码设计思路
- 数据集:使用 MNIST(28x28 灰度图像,展平为 784 维向量),模拟 ( q ( x 1 ) q(x_1) q(x1) )。扩展到 ImageNet 时,只需替换数据加载器并调整网络结构。
- OT 路径:
- ( μ t ( x 1 ) = t x 1 \mu_t(x_1) = t x_1 μt(x1)=tx1 ):均值线性插值。
- ( σ t ( x 1 ) = 1 − ( 1 − σ min ) t \sigma_t(x_1) = 1 - (1 - \sigma_{\text{min}}) t σt(x1)=1−(1−σmin)t ):标准差从 1 线性减小到 ( σ min \sigma_{\text{min}} σmin )(如 0.01)。
- 流:( ψ t ( x ) = ( 1 − ( 1 − σ min ) t ) x + t x 1 \psi_t(x) = (1 - (1 - \sigma_{\text{min}}) t) x + t x_1 ψt(x)=(1−(1−σmin)t)x+tx1 )。
- 向量场:( u t ( x ∣ x 1 ) = x 1 − ( 1 − σ min ) x 1 − ( 1 − σ min ) t u_t(x | x_1) = \frac{x_1 - (1 - \sigma_{\text{min}}) x}{1 - (1 - \sigma_{\text{min}}) t} ut(x∣x1)=1−(1−σmin)tx1−(1−σmin)x )。
- 训练目标:通过条件流匹配损失 ( L CFM = E t , q ( x 1 ) , p ( x 0 ) ∥ v θ ( ψ t ( x 0 ) , t ) − u t ( ψ t ( x 0 ) ∣ x 1 ) ∥ 2 \mathcal{L}_{\text{CFM}} = \mathbb{E}_{t, q(x_1), p(x_0)} \| v_\theta(\psi_t(x_0), t) - u_t(\psi_t(x_0) | x_1) \|^2 LCFM=Et,q(x1),p(x0)∥vθ(ψt(x0),t)−ut(ψt(x0)∣x1)∥2 ) 训练神经网络 ( v θ v_\theta vθ ) 逼近 ( u t u_t ut )。
- 推理:从噪声 ( x 0 ∼ N ( 0 , I ) x_0 \sim \mathcal{N}(0, I) x0∼N(0,I) ) 开始,用学到的 ( v θ v_\theta vθ ) 解 ODE 生成样本。
训练代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
# 设备设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 数据加载:MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# 扩展到 ImageNet:替换为 ImageNet 数据加载器,例如:
# from torchvision.datasets import ImageNet
# train_dataset = ImageNet(root='./imagenet', split='train', transform=transform)
# 输入维度调整为 3x224x224(ImageNet)并展平为 3*224*224=150528
# 超参数
dim = 784 # MNIST 展平维度 (28*28)
sigma_min = 0.01 # OT 路径的最小方差
num_epochs = 10
lr = 1e-3
# 神经网络:预测向量场 v_\theta(x, t)
class VectorFieldNet(nn.Module):
def __init__(self, input_dim):
super(VectorFieldNet, self).__init__()
self.net = nn.Sequential(
nn.Linear(input_dim + 1, 512), # 输入 x 和 t
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, input_dim) # 输出向量场
)
def forward(self, x, t):
t = t.view(-1, 1) # [batch_size, 1]
xt = torch.cat([x, t], dim=1) # 拼接 x 和 t
return self.net(xt)
# OT 路径相关函数
def psi_t(x, x1, t, sigma_min=0.01):
"""计算流 \psi_t(x)"""
scale = 1 - (1 - sigma_min) * t # \sigma_t(x_1)
return scale * x + t * x1 # (1 - (1 - \sigma_min) t) x + t x_1
def u_t(x, x1, t, sigma_min=0.01):
"""计算目标向量场 u_t(x | x_1)"""
scale = 1 - (1 - sigma_min) * t
return (x1 - (1 - sigma_min) * x) / scale # (x_1 - (1 - \sigma_min) x) / (1 - (1 - \sigma_min) t)
# 初始化模型和优化器
model = VectorFieldNet(dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
# 训练循环
for epoch in range(num_epochs):
model.train()
total_loss = 0
for batch_idx, (x1, _) in enumerate(train_loader):
x1 = x1.view(-1, dim).to(device) # x_1 来自 q(x_1),[batch_size, 784]
batch_size = x1.size(0)
# 采样 t 和 x_0
t = torch.rand(batch_size, device=device) # t ~ U[0, 1]
x0 = torch.randn(batch_size, dim, device=device) # x_0 ~ N(0, I)
# 计算 \psi_t(x_0)
xt = psi_t(x0, x1, t, sigma_min)
# 计算目标向量场 u_t
target_u = u_t(xt, x1, t, sigma_min)
# 预测向量场 v_\theta
pred_u = model(xt, t)
# 计算损失
loss = torch.mean((pred_u - target_u) ** 2) # \mathcal{L}_{\text{CFM}}
total_loss += loss.item()
# 优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
avg_loss = total_loss / len(train_loader)
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.6f}")
# 保存模型
torch.save(model.state_dict(), "flow_matching_ot.pth")
代码说明
-
数据加载:
- MNIST 图像展平为 784 维向量,归一化到 [-1, 1]。
- 扩展到 ImageNet:只需替换数据集为 ImageNet,维度调整为 150528(3x224x224),并可能需要更复杂的网络结构(如 CNN)。
-
流 ( ψ t ( x ) \psi_t(x) ψt(x) ):
- 函数
psi_t
计算 ( ψ t ( x ) = ( 1 − ( 1 − σ min ) t ) x + t x 1 \psi_t(x) = (1 - (1 - \sigma_{\text{min}}) t) x + t x_1 ψt(x)=(1−(1−σmin)t)x+tx1 )。 - ( x x x ) 是初始噪声 ( x 0 ∼ N ( 0 , I ) x_0 \sim \mathcal{N}(0, I) x0∼N(0,I) ),( x 1 x_1 x1 ) 是目标样本,( t t t ) 控制插值进度。
- 函数
-
向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ):
- 函数
u_t
计算 ( u t ( x ∣ x 1 ) = x 1 − ( 1 − σ min ) x 1 − ( 1 − σ min ) t u_t(x | x_1) = \frac{x_1 - (1 - \sigma_{\text{min}}) x}{1 - (1 - \sigma_{\text{min}}) t} ut(x∣x1)=1−(1−σmin)tx1−(1−σmin)x )。 - ( x = ψ t ( x 0 ) x = \psi_t(x_0) x=ψt(x0) ) 是当前时间 ( t t t ) 的状态,( u t u_t ut ) 是目标方向。
- 函数
-
训练:
- 采样 ( t ∼ U [ 0 , 1 ] t \sim U[0, 1] t∼U[0,1] ),( x 0 ∼ N ( 0 , I ) x_0 \sim \mathcal{N}(0, I) x0∼N(0,I) ),计算 ( ψ t ( x 0 ) \psi_t(x_0) ψt(x0)) 和 ( u t u_t ut )。
- 神经网络 ( v θ v_\theta vθ ) 预测 ( u t u_t ut ),通过均方误差优化。
推理代码
import torch
from torchdiffeq import odeint # 需要安装 torchdiffeq: pip install torchdiffeq
# 加载模型
model = VectorFieldNet(dim).to(device)
model.load_state_dict(torch.load("flow_matching_ot.pth"))
model.eval()
# ODE 解算器封装向量场
def vector_field(t, x):
"""将 v_\theta 封装为 ODE 函数"""
t_tensor = torch.ones(x.size(0), device=device) * t # 扩展 t 到 batch_size
with torch.no_grad():
return model(x, t_tensor)
# 推理:从噪声生成样本
num_samples = 16
x0 = torch.randn(num_samples, dim, device=device) # x_0 ~ N(0, I)
t_span = torch.linspace(0, 1, 100, device=device) # 时间步长
# 解 ODE
x_t = odeint(vector_field, x0, t_span, method='rk4') # [num_steps, batch_size, dim]
x1_pred = x_t[-1] # 取 t=1 的结果
# 可视化(仅适用于 MNIST)
import matplotlib.pyplot as plt
x1_pred = x1_pred.view(-1, 28, 28).cpu().numpy() # 重塑为图像
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
ax.imshow(x1_pred[i], cmap='gray')
ax.axis('off')
plt.show()
代码说明
-
推理过程:
- 从 ( x 0 ∼ N ( 0 , I ) x_0 \sim \mathcal{N}(0, I) x0∼N(0,I) ) 开始,使用学到的 ( v θ v_\theta vθ ) 作为向量场。
- 通过 ODE 求解器(这里用 Runge-Kutta 4 阶方法)积分从 ( t = 0 t = 0 t=0 ) 到 ( t = 1 t = 1 t=1 ),得到 ( x 1 x_1 x1 )。
-
流 ( ψ t ( x ) \psi_t(x) ψt(x) ) 的计算:
- 在训练中,( ψ t ( x ) \psi_t(x) ψt(x) ) 用于生成中间状态 ( x t x_t xt )。
- 在推理中,( ψ t ( x ) \psi_t(x) ψt(x) ) 被 ( v θ v_\theta vθ ) 隐式逼近,ODE 解算器直接输出路径。
-
扩展到 ImageNet:
- 输入维度改为 150528,网络可能需要用 CNN(如 U-Net)。
- 可视化需调整为彩色图像(3x224x224)。
如何为每个 ( x 1 x_1 x1 ) 设计专属路径?
实现原理
-
条件化设计:
- 对于每个 ( x 1 ∼ q ( x 1 ) x_1 \sim q(x_1) x1∼q(x1) )(如 MNIST 的图像),初始噪声 ( x 0 ∼ N ( 0 , I ) x_0 \sim \mathcal{N}(0, I) x0∼N(0,I) ) 是通用的。
- OT 路径通过 (
ψ
t
(
x
0
)
=
(
1
−
(
1
−
σ
min
)
t
)
x
0
+
t
x
1
\psi_t(x_0) = (1 - (1 - \sigma_{\text{min}}) t) x_0 + t x_1
ψt(x0)=(1−(1−σmin)t)x0+tx1 ) 为每个 (
x
1
x_1
x1 ) 定义一条从 (
x
0
x_0
x0 ) 到 (
x
1
x_1
x1 ) 的专属路径:
- ( t = 0 t = 0 t=0 ):( x 0 x_0 x0 )(噪声)。
- ( t = 1 t = 1 t=1 ):接近 ( x 1 x_1 x1 )(目标)。
-
向量场的作用:
- ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ) 是解析定义的,依赖于当前 ( x x x ) 和目标 ( x 1 x_1 x1 ),指导 ( x x x ) 沿路径移动。
- 神经网络 ( v θ v_\theta vθ ) 学习逼近 ( u t u_t ut ),使其适用于任意 ( x 1 x_1 x1 )。
-
训练中的实现:
- 每次迭代从数据加载器采样 ( x 1 x_1 x1 ),随机生成 ( t t t ) 和 ( x 0 x_0 x0 )。
- 计算 ( ψ t ( x 0 ) \psi_t(x_0) ψt(x0) ) 和 ( u t ( ψ t ( x 0 ) ∣ x 1 ) u_t(\psi_t(x_0) | x_1) ut(ψt(x0)∣x1) ),让 ( v θ v_\theta vθ ) 匹配这条专属路径的方向。
-
推理中的实现:
- 从 ( x 0 x_0 x0 ) 开始,( v θ v_\theta vθ ) 驱动 ODE,隐式生成从噪声到数据的路径。
- 虽然训练时路径针对每个 ( x 1 x_1 x1 ),推理时无需指定 ( x 1 x_1 x1 ),( v θ v_\theta vθ ) 已学到通用映射。
直观理解
- 专属路径:训练时,( x 1 x_1 x1 ) 作为条件输入,( ψ t \psi_t ψt ) 和 ( u t u_t ut) 明确指向 ( x 1 x_1 x1 ),像为每个目标画一条“导航线”。
- 逐步变成目标:( σ t \sigma_t σt ) 从 1 减到 ( σ min \sigma_{\text{min}} σmin ),( μ t \mu_t μt ) 从 0 移到 ( x 1 x_1 x1 ),噪声逐渐“聚焦”到 ( x 1 x_1 x1 )。
总结
这段代码展示了 Flow Matching 使用 OT 路径的完整流程:
- 训练:通过 ( ψ t ( x ) \psi_t(x) ψt(x) ) 和 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ) 为每个 ( x 1 x_1 x1 ) 设计路径,优化 ( v θ v_\theta vθ )。
- 推理:从噪声解 ODE,直接生成样本。
相比扩散模型,Flow Matching 跳过了前向加噪,直接从噪声到数据的路径更高效。扩展到 ImageNet 只需调整数据和网络规模,核心逻辑不变。希望这篇代码和解释让你明白如何将理论落地!
后记
2025年4月7日22点46分于上海,在grok 3大模型辅助下完成。