深入解析 Flow Matching(二):从条件概率路径与向量场到条件流匹配

深入解析 Flow Matching(二):从条件概率路径与向量场到条件流匹配

在生成模型的研究中,Flow Matching(流匹配)是一种高效且优雅的训练方法,基于连续归一化流(Continuous Normalizing Flows, CNFs)。论文《Flow Matching for Generative Modeling》第 4 节“Conditional Probability Paths and Vector Fields”(条件概率路径与向量场)是理解这一方法的关键部分,它详细描述了如何通过条件概率路径和向量场构造生成过程。本文将以通俗的语言介绍这一节的内容,包括 4.1 节“Special Instances of Gaussian Conditional Probability Paths”(高斯条件概率路径的特殊实例),并给出 Theorem 3 的证明,帮助你快速把握来龙去脉。

请先参考笔者的第一篇博客:深入解析 Flow Matching:从条件概率路径与向量场到条件流匹配

paper link: https://arxiv.org/pdf/2210.02747


4 Conditional Probability Paths and Vector Fields 是什么?

想象一下,你想从一团随机噪声(比如标准正态分布)出发,逐步“雕刻”出真实数据的形状(比如图片或文本的分布)。Flow Matching 的核心就是通过一个平滑的“流动”过程实现这一点。第 4 节聚焦于如何为每个目标样本 ( x 1 x_1 x1 )(来自数据分布 ( q ( x 1 ) q(x_1) q(x1) ))设计一个专属的路径,让噪声分布逐步变成目标分布。

基本概念

  • 条件概率 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(xx1) ):这是一个概率密度路径,描述在给定目标 ( x 1 x_1 x1 ) 的情况下,( x x x ) 在时间 ( t t t ) 的分布。

    • ( t = 0 t = 0 t=0 ):起点是噪声,比如 ( p 0 ( x ∣ x 1 ) = N ( 0 , I ) p_0(x | x_1) = \mathcal{N}(0, I) p0(xx1)=N(0,I) )(标准正态分布)。
    • ( t = 1 t = 1 t=1 ):终点接近 ( x 1 x_1 x1 ),比如 ( p 1 ( x ∣ x 1 ) p_1(x | x_1) p1(xx1) ) 几乎集中在 ( x 1 x_1 x1 ) 附近。
  • 流 ( ψ t ( x ) \psi_t(x) ψt(x) ):一个变换函数,像流水线一样,把初始的 ( x x x)(从噪声中采样)推到时间 ( t t t ) 的位置。它通常是一个仿射变换,比如:
    ψ t ( x ) = σ t ( x 1 ) x + μ t ( x 1 ) \psi_t(x) = \sigma_t(x_1) x + \mu_t(x_1) ψt(x)=σt(x1)x+μt(x1)
    其中 ( σ t ( x 1 ) \sigma_t(x_1) σt(x1) ) 控制缩放,( μ t ( x 1 ) \mu_t(x_1) μt(x1) ) 控制平移。

  • 向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ):流的“动力”,告诉 ( ψ t ( x ) \psi_t(x) ψt(x) ) 每时每刻该往哪儿走,满足:
    d d t ψ t ( x ) = u t ( ψ t ( x ) ∣ x 1 ) \frac{d}{dt} \psi_t(x) = u_t(\psi_t(x) | x_1) dtdψt(x)=ut(ψt(x)x1)

为什么需要这些?

Flow Matching 的目标是训练一个神经网络去逼近 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ),但直接定义全局的概率路径 ( p t ( x ) p_t(x) pt(x) ) 和向量场 ( u t ( x ) u_t(x) ut(x) ) 很难(因为数据分布 ( q(x) ) 未知)。通过引入条件路径 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(xx1) ),我们可以先为每个 ( x 1 x_1 x1 ) 设计一个简单路径,再通过平均得到全局效果。这种“分而治之”的思路是第 3 节条件流匹配的延续。


Theorem 3:流与向量场的数学保证

Theorem 3 是这一节的理论基石,它证明了流 ( ψ t ( x ) \psi_t(x) ψt(x) ) 和向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ) 的存在性和一致性。具体表述为:存在 ( ψ t ( x ) \psi_t(x) ψt(x) ) 和 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ),使得:

  1. ( [ ψ t ] ∗ p ( x ) = p t ( x ∣ x 1 ) [\psi_t]_* p(x) = p_t(x | x_1) [ψt]p(x)=pt(xx1) )(流将初始分布推到条件分布, 这里的符号请参考前文提到的博客文章,那里是前置知识)。
  2. ( d d t ψ t ( x ) = u t ( ψ t ( x ) ∣ x 1 ) \frac{d}{dt} \psi_t(x) = u_t(\psi_t(x) | x_1) dtdψt(x)=ut(ψt(x)x1) )(向量场驱动流, 公式(13))。

证明

假设 ( p ( x ) = N ( 0 , I ) p(x) = \mathcal{N}(0, I) p(x)=N(0,I) ),( p t ( x ∣ x 1 ) = N ( x ∣ μ t ( x 1 ) , σ t ( x 1 ) 2 I ) p_t(x | x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I) pt(xx1)=N(xμt(x1),σt(x1)2I) ):

  1. 构造流
    ψ t ( x ) = σ t ( x 1 ) x + μ t ( x 1 ) \psi_t(x) = \sigma_t(x_1) x + \mu_t(x_1) ψt(x)=σt(x1)x+μt(x1)

    • ( t = 0 t = 0 t=0 ):( σ 0 ( x 1 ) = 1 \sigma_0(x_1) = 1 σ0(x1)=1 ),( μ 0 ( x 1 ) = 0 \mu_0(x_1) = 0 μ0(x1)=0 ),( ψ 0 ( x ) = x \psi_0(x) = x ψ0(x)=x ),匹配 ( p 0 ( x ∣ x 1 ) = N ( 0 , I ) p_0(x | x_1) = \mathcal{N}(0, I) p0(xx1)=N(0,I) )。
    • ( t = 1 t = 1 t=1 ):( σ 1 ( x 1 ) ≈ 0 \sigma_1(x_1) \approx 0 σ1(x1)0 ),( μ 1 ( x 1 ) = x 1 \mu_1(x_1) = x_1 μ1(x1)=x1 ),接近 ( x 1 x_1 x1 )。

    验证推前:

    • ( x ∼ N ( 0 , I ) x \sim \mathcal{N}(0, I) xN(0,I) ),则 ( ψ t ( x ) ∼ N ( μ t ( x 1 ) , σ t ( x 1 ) 2 I \psi_t(x) \sim \mathcal{N}(\mu_t(x_1), \sigma_t(x_1)^2 I ψt(x)N(μt(x1),σt(x1)2I) )。
    • 计算 ( [ ψ t ] ∗ p ( x ) = p ( ψ t − 1 ( x ) ) det ⁡ ( ∂ ψ t − 1 ∂ x ) [\psi_t]_* p(x) = p(\psi_t^{-1}(x)) \det\left(\frac{\partial \psi_t^{-1}}{\partial x}\right) [ψt]p(x)=p(ψt1(x))det(xψt1) ):
      ψ t − 1 ( x ) = x − μ t ( x 1 ) σ t ( x 1 ) , det ⁡ ( ∂ ψ t − 1 ∂ x ) = 1 σ t ( x 1 ) d \psi_t^{-1}(x) = \frac{x - \mu_t(x_1)}{\sigma_t(x_1)}, \quad \det\left(\frac{\partial \psi_t^{-1}}{\partial x}\right) = \frac{1}{\sigma_t(x_1)^d} ψt1(x)=σt(x1)xμt(x1),det(xψt1)=σt(x1)d1
      [ ψ t ] ∗ p ( x ) = N ( x ∣ μ t ( x 1 ) , σ t ( x 1 ) 2 I ) = p t ( x ∣ x 1 ) [\psi_t]_* p(x) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I) = p_t(x | x_1) [ψt]p(x)=N(xμt(x1),σt(x1)2I)=pt(xx1)
  2. 计算向量场
    d d t ψ t ( x ) = σ ˙ t ( x 1 ) x + μ ˙ t ( x 1 ) \frac{d}{dt} \psi_t(x) = \dot{\sigma}_t(x_1) x + \dot{\mu}_t(x_1) dtdψt(x)=σ˙t(x1)x+μ˙t(x1)
    设 ( z = ψ t ( x ) z = \psi_t(x) z=ψt(x) ):
    u t ( z ∣ x 1 ) = σ ˙ t ( x 1 ) ⋅ z − μ t ( x 1 ) σ t ( x 1 ) + μ ˙ t ( x 1 ) u_t(z | x_1) = \dot{\sigma}_t(x_1) \cdot \frac{z - \mu_t(x_1)}{\sigma_t(x_1)} + \dot{\mu}_t(x_1) ut(zx1)=σ˙t(x1)σt(x1)zμt(x1)+μ˙t(x1)
    代入 ( z = ψ t ( x ) z = \psi_t(x) z=ψt(x) ),满足公式 (13)。

  3. 验证一致性
    连续性方程 ( d d t p t ( x ∣ x 1 ) + div ⁡ ( u t ( x ∣ x 1 ) p t ( x ∣ x 1 ) ) = 0 \frac{d}{dt} p_t(x | x_1) + \operatorname{div}(u_t(x | x_1) p_t(x | x_1)) = 0 dtdpt(xx1)+div(ut(xx1)pt(xx1))=0 ) 保证 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ) 生成 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(xx1) )。具体计算较复杂(涉及高斯导数和散度),但可通过 4.1 节实例验证。

结论:( ψ t ( x ) \psi_t(x) ψt(x) ) 和 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ) 满足要求。


Flow Matching 进阶:高斯条件概率路径的特殊实例

在《Flow Matching for Generative Modeling》第 4 节“Conditional Probability Paths and Vector Fields”中,我们了解了条件概率路径 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(xx1) ) 和向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ) 的基本概念。4.1 节“Special Instances of Gaussian Conditional Probability Paths”(高斯条件概率路径的特殊实例)则进一步展示了这一框架的灵活性,通过调整 ( μ t ( x 1 ) \mu_t(x_1) μt(x1) )(均值)和 ( σ t ( x 1 ) \sigma_t(x_1) σt(x1) )(标准差),我们可以设计出不同的生成路径。本文将深入探讨这一节,介绍两种典型实例——扩散条件路径和最优传输路径,带你理解它们的来龙去脉和应用场景。


4.1 高斯条件概率路径的特殊实例:灵活设计的艺术

Flow Matching 的条件概率路径 ( p t ( x ∣ x 1 ) = N ( x ∣ μ t ( x 1 ) , σ t ( x 1 ) 2 I ) p_t(x | x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I) pt(xx1)=N(xμt(x1),σt(x1)2I) ) 是一个通用框架,( μ t ( x 1 ) \mu_t(x_1) μt(x1) ) 和 ( σ t ( x 1 ) \sigma_t(x_1) σt(x1) ) 可以是任意满足边界条件的可微函数:

  • ( t = 0 t = 0 t=0 ):起点通常是噪声分布(如 ( N ( 0 , I ) \mathcal{N}(0, I) N(0,I) ) 或接近噪声)。
  • ( t = 1 t = 1 t=1 ):终点接近目标样本 ( x 1 x_1 x1 )(如 ( N ( x 1 , σ min 2 I ) \mathcal{N}(x_1, \sigma_{\text{min}}^2 I) N(x1,σmin2I) ))。

这种灵活性让 Flow Matching 既能复现传统扩散过程,又能探索全新的路径设计。4.1 节提供了两个例子:一个连接扩散模型,一个基于最优传输(Optimal Transport, OT),展示了这种设计的多样性。


示例 I:扩散条件向量场

背景与定义

扩散模型(Diffusion Models)通过从数据加噪到纯噪声(前向过程),再从噪声去噪到数据(反向过程),实现生成。Flow Matching 可以复现这种扩散路径的反向过程,定义特定的 ( μ t ( x 1 ) \mu_t(x_1) μt(x1) ) 和 ( σ t ( x 1 ) \sigma_t(x_1) σt(x1) )。

1. Variance Exploding (VE) 路径
  • 路径
    p t ( x ∣ x 1 ) = N ( x ∣ x 1 , σ 1 − t 2 I ) p_t(x | x_1) = \mathcal{N}(x | x_1, \sigma_{1-t}^2 I) pt(xx1)=N(xx1,σ1t2I)
    • ( σ t \sigma_t σt ) 是递增函数,( σ 0 = 0 \sigma_0 = 0 σ0=0 ),( σ 1 ≫ 1 \sigma_1 \gg 1 σ11 )。
    • ( t = 0 t = 0 t=0 ):( p 0 ( x ∣ x 1 ) = N ( x 1 , σ 1 2 I ) p_0(x | x_1) = \mathcal{N}(x_1, \sigma_1^2 I) p0(xx1)=N(x1,σ12I) )(高方差噪声)。
    • ( t = 1 t = 1 t=1 ):( p 1 ( x ∣ x 1 ) = N ( x 1 , 0 ) p_1(x | x_1) = \mathcal{N}(x_1, 0) p1(xx1)=N(x1,0) )(集中在 ( x 1 x_1 x1 ))。
  • 设置
    μ t ( x 1 ) = x 1 , σ t ( x 1 ) = σ 1 − t \mu_t(x_1) = x_1, \quad \sigma_t(x_1) = \sigma_{1-t} μt(x1)=x1,σt(x1)=σ1t
  • 向量场(由 Theorem 3 的公式 15 ( u t ( x ∣ x 1 ) = σ ˙ t ( x 1 ) ⋅ x − μ t ( x 1 ) σ t ( x 1 ) + μ ˙ t ( x 1 ) u_t(x | x_1) = \dot{\sigma}_t(x_1) \cdot \frac{x - \mu_t(x_1)}{\sigma_t(x_1)} + \dot{\mu}_t(x_1) ut(xx1)=σ˙t(x1)σt(x1)xμt(x1)+μ˙t(x1) ) 导出):
    σ ˙ t ( x 1 ) = d d t σ 1 − t = − σ 1 − t ′ , μ ˙ t ( x 1 ) = 0 \dot{\sigma}_t(x_1) = \frac{d}{dt} \sigma_{1-t} = -\sigma'_{1-t}, \quad \dot{\mu}_t(x_1) = 0 σ˙t(x1)=dtdσ1t=σ1t,μ˙t(x1)=0
    u t ( x ∣ x 1 ) = − σ 1 − t ′ σ 1 − t ( x − x 1 ) u_t(x | x_1) = -\frac{\sigma'_{1-t}}{\sigma_{1-t}} (x - x_1) ut(xx1)=σ1tσ1t(xx1)
2. Variance Preserving (VP) 路径
  • 路径
    p t ( x ∣ x 1 ) = N ( x ∣ α 1 − t x 1 , ( 1 − α 1 − t 2 ) I ) p_t(x | x_1) = \mathcal{N}(x | \alpha_{1-t} x_1, (1 - \alpha_{1-t}^2) I) pt(xx1)=N(xα1tx1,(1α1t2)I)
    • ( α t = e − 1 2 T ( t ) \alpha_t = e^{-\frac{1}{2} T(t)} αt=e21T(t) ),( T ( t ) = ∫ 0 t β ( s )   d s T(t) = \int_0^t \beta(s) \, ds T(t)=0tβ(s)ds ),( β ( s ) \beta(s) β(s) ) 是噪声调度函数。
    • ( t = 0 t = 0 t=0 ):( p 0 ( x ∣ x 1 ) = N ( x 1 , 1 ) p_0(x | x_1) = \mathcal{N}(x_1, 1) p0(xx1)=N(x1,1) )(接近噪声)。
    • ( t = 1 t = 1 t=1 ):( p 1 ( x ∣ x 1 ) = N ( x 1 , 0 ) p_1(x | x_1) = \mathcal{N}(x_1, 0) p1(xx1)=N(x1,0) )(集中在 ( x 1 x_1 x1 ))。
  • 设置
    μ t ( x 1 ) = α 1 − t x 1 , σ t ( x 1 ) = 1 − α 1 − t 2 \mu_t(x_1) = \alpha_{1-t} x_1, \quad \sigma_t(x_1) = \sqrt{1 - \alpha_{1-t}^2} μt(x1)=α1tx1,σt(x1)=1α1t2
  • 向量场
    μ ˙ t ( x 1 ) = − α 1 − t ′ x 1 , σ ˙ t ( x 1 ) = − α 1 − t α 1 − t ′ 1 − α 1 − t 2 \dot{\mu}_t(x_1) = -\alpha'_{1-t} x_1, \quad \dot{\sigma}_t(x_1) = \frac{-\alpha_{1-t} \alpha'_{1-t}}{\sqrt{1 - \alpha_{1-t}^2}} μ˙t(x1)=α1tx1,σ˙t(x1)=1α1t2 α1tα1t
    u t ( x ∣ x 1 ) = − α 1 − t ′ 1 − α 1 − t 2 ( x − α 1 − t x 1 ) − α 1 − t ′ x 1 u_t(x | x_1) = -\frac{\alpha'_{1-t}}{\sqrt{1 - \alpha_{1-t}^2}} (x - \alpha_{1-t} x_1) - \alpha'_{1-t} x_1 ut(xx1)=1α1t2 α1t(xα1tx1)α1tx1
    化简后:
    u t ( x ∣ x 1 ) = − T ′ ( 1 − t ) 2 ⋅ e − T ( 1 − t ) x − e − 1 2 T ( 1 − t ) x 1 1 − e − T ( 1 − t ) u_t(x | x_1) = -\frac{T'(1 - t)}{2} \cdot \frac{e^{-T(1-t)} x - e^{-\frac{1}{2} T(1-t)} x_1}{1 - e^{-T(1-t)}} ut(xx1)=2T(1t)1eT(1t)eT(1t)xe21T(1t)x1

联系与特点

  • 与扩散模型的联系:这些路径与 Song 等人的概率流(Probability Flow)一致(见 Song et al., 2020b),但 Flow Matching 通过直接回归向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ),避免了扩散模型中复杂的得分匹配(Score Matching),实验中更稳定。
  • 局限:扩散路径受限于其起源(扩散过程的解),( t = 0 t = 0 t=0 ) 时并非纯噪声(如 ( N ( 0 , I ) \mathcal{N}(0, I) N(0,I) )),而是近似,需要额外调整。

示例 II:最优传输条件向量场

背景与定义

既然 Flow Matching 不依赖扩散过程,我们可以直接设计更自然的路径,比如基于 Wasserstein-2 最优传输(Optimal Transport, OT)的路径,让均值和方差线性变化。

  • 路径
    μ t ( x 1 ) = t x 1 , σ t ( x 1 ) = 1 − ( 1 − σ min ) t \mu_t(x_1) = t x_1, \quad \sigma_t(x_1) = 1 - (1 - \sigma_{\text{min}}) t μt(x1)=tx1,σt(x1)=1(1σmin)t
    • ( t = 0 t = 0 t=0 ):( p 0 ( x ∣ x 1 ) = N ( 0 , 1 ) p_0(x | x_1) = \mathcal{N}(0, 1) p0(xx1)=N(0,1) )(标准正态)。
    • ( t = 1 t = 1 t=1 ):( p 1 ( x ∣ x 1 ) = N ( x 1 , σ min 2 ) p_1(x | x_1) = \mathcal{N}(x_1, \sigma_{\text{min}}^2) p1(xx1)=N(x1,σmin2) )(接近 ( x 1 x_1 x1 ))。

  • ψ t ( x ) = ( 1 − ( 1 − σ min ) t ) x + t x 1 \psi_t(x) = (1 - (1 - \sigma_{\text{min}}) t) x + t x_1 ψt(x)=(1(1σmin)t)x+tx1
  • 向量场
    μ ˙ t ( x 1 ) = x 1 , σ ˙ t ( x 1 ) = − ( 1 − σ min ) \dot{\mu}_t(x_1) = x_1, \quad \dot{\sigma}_t(x_1) = -(1 - \sigma_{\text{min}}) μ˙t(x1)=x1,σ˙t(x1)=(1σmin)
    u t ( x ∣ x 1 ) = − ( 1 − σ min ) ⋅ x − t x 1 1 − ( 1 − σ min ) t + x 1 = x 1 − ( 1 − σ min ) x 1 − ( 1 − σ min ) t u_t(x | x_1) = -(1 - \sigma_{\text{min}}) \cdot \frac{x - t x_1}{1 - (1 - \sigma_{\text{min}}) t} + x_1 = \frac{x_1 - (1 - \sigma_{\text{min}}) x}{1 - (1 - \sigma_{\text{min}}) t} ut(xx1)=(1σmin)1(1σmin)txtx1+x1=1(1σmin)tx1(1σmin)x
  • 损失
    L CFM ( θ ) = E t , q ( x 1 ) , p ( x 0 ) ∥ v t ( ψ t ( x 0 ) ) − x 1 − ( 1 − σ min ) x 0 1 − ( 1 − σ min ) t ∥ 2 \mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t, q(x_1), p(x_0)} \left\| v_t(\psi_t(x_0)) - \frac{x_1 - (1 - \sigma_{\text{min}}) x_0}{1 - (1 - \sigma_{\text{min}}) t} \right\|^2 LCFM(θ)=Et,q(x1),p(x0) vt(ψt(x0))1(1σmin)tx1(1σmin)x0 2

最优传输的意义

  • OT 特性:( ψ t ( x ) \psi_t(x) ψt(x) ) 是从 ( p 0 ( x ∣ x 1 ) = N ( 0 , 1 ) p_0(x | x_1) = \mathcal{N}(0, 1) p0(xx1)=N(0,1) ) 到 ( p 1 ( x ∣ x 1 ) = N ( x 1 , σ min 2 ) p_1(x | x_1) = \mathcal{N}(x_1, \sigma_{\text{min}}^2) p1(xx1)=N(x1,σmin2) ) 的 OT 位移映射(Displacement Map),路径为:
    p t = [ ( 1 − t ) id + t ψ ] ∗ p 0 p_t = [(1 - t) \text{id} + t \psi]_* p_0 pt=[(1t)id+tψ]p0
    粒子沿直线移动,速度恒定(见 McCann, 1997)。
  • 优势
    • 相比扩散路径(可能“超调”后回溯),OT 路径始终直线前进,避免浪费。
    • 定义在整个 ( t ∈ [ 0 , 1 ] t \in [0, 1] t[0,1] ),控制更直接。

总结

4.1 节展示了 Flow Matching 的强大灵活性:

  • 扩散条件路径复现了传统扩散模型的反向过程(如 VE 和 VP),通过解析向量场提升训练效率,但受限于扩散起源。
  • 最优传输路径跳出扩散框架,利用 OT 理论设计直线路径,简单高效且直观。

这些实例不仅连接了已有方法,还开辟了新思路。无论是借鑒扩散模型的经验,还是探索 OT 的优雅,Flow Matching 都为生成模型提供了丰富的可能性。希望这篇博客让你对高斯条件概率路径的设计有更直观的认识!如果想深入某个实例,欢迎继续讨论。

Flow Matching 进阶:高斯条件概率路径的特殊实例(补充版)

在《Flow Matching for Generative Modeling》第 4.1 节“Special Instances of Gaussian Conditional Probability Paths”(高斯条件概率路径的特殊实例)中,我们探讨了如何通过调整 ( μ t ( x 1 ) \mu_t(x_1) μt(x1) ) 和 ( σ t ( x 1 ) \sigma_t(x_1) σt(x1) ) 设计不同的生成路径。本文在之前的博客基础上,补充了对 Variance Exploding (VE) 路径和 Variance Preserving (VP) 路径的详细解释,以及对最优传输(OT)路径中 OT 位移映射公式的拆解,帮助你更全面地理解这些概念的来龙去脉。


示例 I:扩散条件向量场

什么是 Variance Exploding (VE) 和 Variance Preserving (VP) 路径?

扩散模型(Diffusion Models)通过前向过程加噪、反向过程去噪生成数据,其概率路径通常是高斯形式。VE 和 VP 是两种常见的路径设计,分别对应不同的噪声スケール策略。

Variance Exploding (VE) 路径
  • 定义
    p t ( x ∣ x 1 ) = N ( x ∣ x 1 , σ 1 − t 2 I ) p_t(x | x_1) = \mathcal{N}(x | x_1, \sigma_{1-t}^2 I) pt(xx1)=N(xx1,σ1t2I)
    • ( μ t ( x 1 ) = x 1 \mu_t(x_1) = x_1 μt(x1)=x1 ),均值固定。
    • ( σ t ( x 1 ) = σ 1 − t \sigma_t(x_1) = \sigma_{1-t} σt(x1)=σ1t ),( σ t \sigma_t σt ) 是递增函数,( σ 0 = 0 \sigma_0 = 0 σ0=0 ),( σ 1 ≫ 1 \sigma_1 \gg 1 σ11 )(如 1000)。
    • ( t = 0 t = 0 t=0 ):( p 0 ( x ∣ x 1 ) = N ( x 1 , σ 1 2 I ) p_0(x | x_1) = \mathcal{N}(x_1, \sigma_1^2 I) p0(xx1)=N(x1,σ12I) )(高方差噪声)。
    • ( t = 1 t = 1 t=1 ):( p 1 ( x ∣ x 1 ) = N ( x 1 , 0 ) p_1(x | x_1) = \mathcal{N}(x_1, 0) p1(xx1)=N(x1,0) )(集中在 ( x 1 x_1 x1 ))。
  • 向量场
    u t ( x ∣ x 1 ) = − σ 1 − t ′ σ 1 − t ( x − x 1 ) u_t(x | x_1) = -\frac{\sigma'_{1-t}}{\sigma_{1-t}} (x - x_1) ut(xx1)=σ1tσ1t(xx1)
  • 特点:方差“爆炸式”增长,前向过程从数据加噪到接近纯噪声(( σ 1 2 \sigma_1^2 σ12 ) 很大),反向过程则逐渐减小方差。
Variance Preserving (VP) 路径
  • 定义
    p t ( x ∣ x 1 ) = N ( x ∣ α 1 − t x 1 , ( 1 − α 1 − t 2 ) I ) p_t(x | x_1) = \mathcal{N}(x | \alpha_{1-t} x_1, (1 - \alpha_{1-t}^2) I) pt(xx1)=N(xα1tx1,(1α1t2)I)
    • ( α t = e − 1 2 T ( t ) \alpha_t = e^{-\frac{1}{2} T(t)} αt=e21T(t) ),( T ( t ) = ∫ 0 t β ( s )   d s T(t) = \int_0^t \beta(s) \, ds T(t)=0tβ(s)ds ),( β ( s ) \beta(s) β(s)) 是噪声调度函数。
    • ( μ t ( x 1 ) = α 1 − t x 1 \mu_t(x_1) = \alpha_{1-t} x_1 μt(x1)=α1tx1 ),均值随时间缩减。
    • ( σ t ( x 1 ) = 1 − α 1 − t 2 \sigma_t(x_1) = \sqrt{1 - \alpha_{1-t}^2} σt(x1)=1α1t2 )。
    • ( t = 0 t = 0 t=0 ):( p 0 ( x ∣ x 1 ) = N ( x 1 , 1 ) p_0(x | x_1) = \mathcal{N}(x_1, 1) p0(xx1)=N(x1,1) )(接近噪声)。
    • ( t = 1 t = 1 t=1 ):( p 1 ( x ∣ x 1 ) = N ( x 1 , 0 ) p_1(x | x_1) = \mathcal{N}(x_1, 0) p1(xx1)=N(x1,0) )(集中在 ( x 1 x_1 x1 ))。
  • 向量场
    u t ( x ∣ x 1 ) = − T ′ ( 1 − t ) 2 ⋅ e − T ( 1 − t ) x − e − 1 2 T ( 1 − t ) x 1 1 − e − T ( 1 − t ) u_t(x | x_1) = -\frac{T'(1 - t)}{2} \cdot \frac{e^{-T(1-t)} x - e^{-\frac{1}{2} T(1-t)} x_1}{1 - e^{-T(1-t)}} ut(xx1)=2T(1t)1eT(1t)eT(1t)xe21T(1t)x1
  • 特点:方差保持在一定范围内(从 1 到 0),均值逐渐从 ( x 1 x_1 x1 ) 缩到 0(前向),反向则恢复。
VE 和 VP 的不同
  • 方差行为
    • VE:方差从 0 增加到非常大(如 ( σ 1 2 ≫ 1 \sigma_1^2 \gg 1 σ121 )),强调“爆炸”效应。
    • VP:方差从 0 到 1 再回到 0,保持相对稳定,避免过大噪声。
  • 均值行为
    • VE:均值始终是 ( x 1 x_1 x1 ),不变。
    • VP:均值从 ( x 1 x_1 x1 ) 缩到接近 0(前向),反向恢复到 ( x 1 x_1 x1 )。
  • 适用场景
    • VE:适合需要强噪声的任务(如高维数据),但计算复杂度高。
    • VP:更平衡,常用在图像生成(如 DDPM),稳定性更好。
为什么提出这两种路径?
  • VE 的起源:Sohl-Dickstein 等(2015)提出扩散模型时,VE 路径通过大方差模拟纯噪声,便于理论分析和采样。
  • VP 的改进:Ho 等(2020)发现 VE 的极端方差会导致训练不稳定,提出 VP 路径,通过控制 ( \beta(t) ) 保持方差适中,提升生成质量。Song 等(2020b)进一步将 VP 路径用于概率流,奠定了其主流地位。
  • 动机:两者都旨在通过高斯路径简化扩散过程的闭式解,但 VE 更理论化,VP 更实用化。

示例 II:最优传输条件向量场

定义与路径

  • 设置
    μ t ( x 1 ) = t x 1 , σ t ( x 1 ) = 1 − ( 1 − σ min ) t \mu_t(x_1) = t x_1, \quad \sigma_t(x_1) = 1 - (1 - \sigma_{\text{min}}) t μt(x1)=tx1,σt(x1)=1(1σmin)t
    • ( t = 0 t = 0 t=0 ):( p 0 ( x ∣ x 1 ) = N ( 0 , 1 ) p_0(x | x_1) = \mathcal{N}(0, 1) p0(xx1)=N(0,1) )。
    • ( t = 1 t = 1 t=1 ):( p 1 ( x ∣ x 1 ) = N ( x 1 , σ min 2 ) p_1(x | x_1) = \mathcal{N}(x_1, \sigma_{\text{min}}^2) p1(xx1)=N(x1,σmin2) )(( σ min \sigma_{\text{min}} σmin ) 通常很小)。

  • ψ t ( x ) = ( 1 − ( 1 − σ min ) t ) x + t x 1 \psi_t(x) = (1 - (1 - \sigma_{\text{min}}) t) x + t x_1 ψt(x)=(1(1σmin)t)x+tx1
  • 向量场
    u t ( x ∣ x 1 ) = x 1 − ( 1 − σ min ) x 1 − ( 1 − σ min ) t u_t(x | x_1) = \frac{x_1 - (1 - \sigma_{\text{min}}) x}{1 - (1 - \sigma_{\text{min}}) t} ut(xx1)=1(1σmin)tx1(1σmin)x

OT 特性详解

  • OT 特性
    p t = [ ( 1 − t ) id + t ψ ] ∗ p 0 p_t = [(1 - t) \text{id} + t \psi]_* p_0 pt=[(1t)id+tψ]p0
    • ( ψ t ( x ) \psi_t(x) ψt(x) ) 是从 ( p 0 ( x ∣ x 1 ) = N ( 0 , 1 ) p_0(x | x_1) = \mathcal{N}(0, 1) p0(xx1)=N(0,1) ) 到 ( p 1 ( x ∣ x 1 ) = N ( x 1 , σ min 2 ) p_1(x | x_1) = \mathcal{N}(x_1, \sigma_{\text{min}}^2) p1(xx1)=N(x1,σmin2) ) 的 OT 位移映射(Displacement Map)。
    • 粒子沿直线移动,速度恒定(见 McCann, 1997)。
公式拆解
  1. ( id \text{id} id ) 是什么?

    • ( id \text{id} id ) 是恒等映射(identity map),即 ( id ( x ) = x \text{id}(x) = x id(x)=x )。它表示不做任何变换,直接输出输入本身。
    • 在这里,( id \text{id} id ) 代表路径的起点 ( p 0 p_0 p0 ) 的位置。
  2. ( ( 1 − t ) id + t ψ (1 - t) \text{id} + t \psi (1t)id+tψ ) 是什么?

    • 这是一个线性插值函数,称为 OT 位移映射:
      ( 1 − t ) id ( x ) + t ψ ( x ) = ( 1 − t ) x + t ψ ( x ) (1 - t) \text{id}(x) + t \psi(x) = (1 - t) x + t \psi(x) (1t)id(x)+tψ(x)=(1t)x+tψ(x)
    • 代入 ( ψ ( x ) = x 1 \psi(x) = x_1 ψ(x)=x1 )(假设 ( σ min = 0 \sigma_{\text{min}} = 0 σmin=0 ) 简化理解):
      ( 1 − t ) x + t x 1 (1 - t) x + t x_1 (1t)x+tx1
    • ( t = 0 t = 0 t=0 ):输出 ( x x x )(起点)。
    • ( t = 1 t = 1 t=1 ):输出 ( x 1 x_1 x1 )(终点)。
    • 随着 ( t t t ) 从 0 到 1,粒子从 ( x x x ) 直线移动到 ( x 1 x_1 x1 )。
  3. ( [ ( 1 − t ) id + t ψ ] ∗ p 0 [(1 - t) \text{id} + t \psi]_* p_0 [(1t)id+tψ]p0 ) 是什么?

    • ( [ ⋅ ] ∗ [\cdot]_* [] ) 表示推前(push-forward)操作,将分布 ( p 0 p_0 p0 ) 通过映射变换到新分布。
    • ( p 0 ( x ∣ x 1 ) = N ( 0 , 1 ) p_0(x | x_1) = \mathcal{N}(0, 1) p0(xx1)=N(0,1) ) 是初始分布。
    • 令 ( ϕ t ( x ) = ( 1 − t ) x + t ψ ( x ) = ( 1 − ( 1 − σ min ) t ) x + t x 1 \phi_t(x) = (1 - t) x + t \psi(x) = (1 - (1 - \sigma_{\text{min}}) t) x + t x_1 ϕt(x)=(1t)x+tψ(x)=(1(1σmin)t)x+tx1 ):
      • ( x ∼ N ( 0 , 1 ) x \sim \mathcal{N}(0, 1) xN(0,1) )。
      • ( ϕ t ( x ) ∼ N ( t x 1 , ( 1 − ( 1 − σ min ) t ) 2 I ) \phi_t(x) \sim \mathcal{N}(t x_1, (1 - (1 - \sigma_{\text{min}}) t)^2 I) ϕt(x)N(tx1,(1(1σmin)t)2I) )(均值和方差通过线性变换计算)。
    • 因此:
      p t ( x ) = [ ( 1 − t ) id + t ψ ] ∗ p 0 = N ( x ∣ t x 1 , ( 1 − ( 1 − σ min ) t ) 2 I ) p_t(x) = [(1 - t) \text{id} + t \psi]_* p_0 = \mathcal{N}(x | t x_1, (1 - (1 - \sigma_{\text{min}}) t)^2 I) pt(x)=[(1t)id+tψ]p0=N(xtx1,(1(1σmin)t)2I)
    • 这与定义一致,验证了 OT 路径的正确性。
直观理解
  • 直线移动:粒子从 ( x ∼ N ( 0 , 1 ) x \sim \mathcal{N}(0, 1) xN(0,1) ) 到 ( x 1 x_1 x1 ) 的路径是直线,速度 ( x 1 − x 1 \frac{x_1 - x}{1} 1x1x ) 恒定。
  • OT 的优雅:OT 路径是最优的(Wasserstein-2 距离最小),避免了扩散路径可能出现的“超调”或迂回。

总结

4.1 节通过 VE 和 VP 路径展示了 Flow Matching 与扩散模型的联系:

  • VE:方差爆炸,适合强噪声场景。
  • VP:方差平稳,更实用且稳定。

而 OT 路径则跳出扩散框架:

  • 通过 ( p t = [ ( 1 − t ) id + t ψ ] ∗ p 0 p_t = [(1 - t) \text{id} + t \psi]_* p_0 pt=[(1t)id+tψ]p0 ),利用 OT 位移映射实现直线高效移动,简单且优雅。

这些设计不仅复现了传统方法,还拓展了新可能。希望这篇补充博客让你对 VE、VP 和 OT 路径的细节一目了然!

Flow Matching 的训练和推理代码示例

下面将提供一个基于 Flow Matching 的训练和推理代码示例,使用最优传输(Optimal Transport, OT)路径。我们将聚焦于如何计算流 ( ψ t ( x ) \psi_t(x) ψt(x) ) 和向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ),并展示如何为每个目标样本 ( x 1 x_1 x1 ) 设计专属路径,让噪声分布 ( N ( 0 , I ) \mathcal{N}(0, I) N(0,I) ) 逐步变成目标分布 ( q ( x 1 ) q(x_1) q(x1) )。由于直接使用 ImageNet 需要大量预处理和计算资源,我将以 MNIST 数据集为例(一个简单的图片数据集),并在代码中注释说明如何扩展到 ImageNet。


代码设计思路

  1. 数据集:使用 MNIST(28x28 灰度图像,展平为 784 维向量),模拟 ( q ( x 1 ) q(x_1) q(x1) )。扩展到 ImageNet 时,只需替换数据加载器并调整网络结构。
  2. OT 路径
    • ( μ t ( x 1 ) = t x 1 \mu_t(x_1) = t x_1 μt(x1)=tx1 ):均值线性插值。
    • ( σ t ( x 1 ) = 1 − ( 1 − σ min ) t \sigma_t(x_1) = 1 - (1 - \sigma_{\text{min}}) t σt(x1)=1(1σmin)t ):标准差从 1 线性减小到 ( σ min \sigma_{\text{min}} σmin )(如 0.01)。
    • 流:( ψ t ( x ) = ( 1 − ( 1 − σ min ) t ) x + t x 1 \psi_t(x) = (1 - (1 - \sigma_{\text{min}}) t) x + t x_1 ψt(x)=(1(1σmin)t)x+tx1 )。
    • 向量场:( u t ( x ∣ x 1 ) = x 1 − ( 1 − σ min ) x 1 − ( 1 − σ min ) t u_t(x | x_1) = \frac{x_1 - (1 - \sigma_{\text{min}}) x}{1 - (1 - \sigma_{\text{min}}) t} ut(xx1)=1(1σmin)tx1(1σmin)x )。
  3. 训练目标:通过条件流匹配损失 ( L CFM = E t , q ( x 1 ) , p ( x 0 ) ∥ v θ ( ψ t ( x 0 ) , t ) − u t ( ψ t ( x 0 ) ∣ x 1 ) ∥ 2 \mathcal{L}_{\text{CFM}} = \mathbb{E}_{t, q(x_1), p(x_0)} \| v_\theta(\psi_t(x_0), t) - u_t(\psi_t(x_0) | x_1) \|^2 LCFM=Et,q(x1),p(x0)vθ(ψt(x0),t)ut(ψt(x0)x1)2 ) 训练神经网络 ( v θ v_\theta vθ ) 逼近 ( u t u_t ut )。
  4. 推理:从噪声 ( x 0 ∼ N ( 0 , I ) x_0 \sim \mathcal{N}(0, I) x0N(0,I) ) 开始,用学到的 ( v θ v_\theta vθ ) 解 ODE 生成样本。

训练代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# 设备设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 数据加载:MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# 扩展到 ImageNet:替换为 ImageNet 数据加载器,例如:
# from torchvision.datasets import ImageNet
# train_dataset = ImageNet(root='./imagenet', split='train', transform=transform)
# 输入维度调整为 3x224x224(ImageNet)并展平为 3*224*224=150528

# 超参数
dim = 784  # MNIST 展平维度 (28*28)
sigma_min = 0.01  # OT 路径的最小方差
num_epochs = 10
lr = 1e-3

# 神经网络:预测向量场 v_\theta(x, t)
class VectorFieldNet(nn.Module):
    def __init__(self, input_dim):
        super(VectorFieldNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim + 1, 512),  # 输入 x 和 t
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim)  # 输出向量场
        )
    
    def forward(self, x, t):
        t = t.view(-1, 1)  # [batch_size, 1]
        xt = torch.cat([x, t], dim=1)  # 拼接 x 和 t
        return self.net(xt)

# OT 路径相关函数
def psi_t(x, x1, t, sigma_min=0.01):
    """计算流 \psi_t(x)"""
    scale = 1 - (1 - sigma_min) * t  # \sigma_t(x_1)
    return scale * x + t * x1  # (1 - (1 - \sigma_min) t) x + t x_1

def u_t(x, x1, t, sigma_min=0.01):
    """计算目标向量场 u_t(x | x_1)"""
    scale = 1 - (1 - sigma_min) * t
    return (x1 - (1 - sigma_min) * x) / scale  # (x_1 - (1 - \sigma_min) x) / (1 - (1 - \sigma_min) t)

# 初始化模型和优化器
model = VectorFieldNet(dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# 训练循环
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch_idx, (x1, _) in enumerate(train_loader):
        x1 = x1.view(-1, dim).to(device)  # x_1 来自 q(x_1),[batch_size, 784]
        batch_size = x1.size(0)

        # 采样 t 和 x_0
        t = torch.rand(batch_size, device=device)  # t ~ U[0, 1]
        x0 = torch.randn(batch_size, dim, device=device)  # x_0 ~ N(0, I)

        # 计算 \psi_t(x_0)
        xt = psi_t(x0, x1, t, sigma_min)

        # 计算目标向量场 u_t
        target_u = u_t(xt, x1, t, sigma_min)

        # 预测向量场 v_\theta
        pred_u = model(xt, t)

        # 计算损失
        loss = torch.mean((pred_u - target_u) ** 2)  # \mathcal{L}_{\text{CFM}}
        total_loss += loss.item()

        # 优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.6f}")

# 保存模型
torch.save(model.state_dict(), "flow_matching_ot.pth")

代码说明

  1. 数据加载

    • MNIST 图像展平为 784 维向量,归一化到 [-1, 1]。
    • 扩展到 ImageNet:只需替换数据集为 ImageNet,维度调整为 150528(3x224x224),并可能需要更复杂的网络结构(如 CNN)。
  2. 流 ( ψ t ( x ) \psi_t(x) ψt(x) )

    • 函数 psi_t 计算 ( ψ t ( x ) = ( 1 − ( 1 − σ min ) t ) x + t x 1 \psi_t(x) = (1 - (1 - \sigma_{\text{min}}) t) x + t x_1 ψt(x)=(1(1σmin)t)x+tx1 )。
    • ( x x x ) 是初始噪声 ( x 0 ∼ N ( 0 , I ) x_0 \sim \mathcal{N}(0, I) x0N(0,I) ),( x 1 x_1 x1 ) 是目标样本,( t t t ) 控制插值进度。
  3. 向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) )

    • 函数 u_t 计算 ( u t ( x ∣ x 1 ) = x 1 − ( 1 − σ min ) x 1 − ( 1 − σ min ) t u_t(x | x_1) = \frac{x_1 - (1 - \sigma_{\text{min}}) x}{1 - (1 - \sigma_{\text{min}}) t} ut(xx1)=1(1σmin)tx1(1σmin)x )。
    • ( x = ψ t ( x 0 ) x = \psi_t(x_0) x=ψt(x0) ) 是当前时间 ( t t t ) 的状态,( u t u_t ut ) 是目标方向。
  4. 训练

    • 采样 ( t ∼ U [ 0 , 1 ] t \sim U[0, 1] tU[0,1] ),( x 0 ∼ N ( 0 , I ) x_0 \sim \mathcal{N}(0, I) x0N(0,I) ),计算 ( ψ t ( x 0 ) \psi_t(x_0) ψt(x0)) 和 ( u t u_t ut )。
    • 神经网络 ( v θ v_\theta vθ ) 预测 ( u t u_t ut ),通过均方误差优化。

推理代码

import torch
from torchdiffeq import odeint  # 需要安装 torchdiffeq: pip install torchdiffeq

# 加载模型
model = VectorFieldNet(dim).to(device)
model.load_state_dict(torch.load("flow_matching_ot.pth"))
model.eval()

# ODE 解算器封装向量场
def vector_field(t, x):
    """将 v_\theta 封装为 ODE 函数"""
    t_tensor = torch.ones(x.size(0), device=device) * t  # 扩展 t 到 batch_size
    with torch.no_grad():
        return model(x, t_tensor)

# 推理:从噪声生成样本
num_samples = 16
x0 = torch.randn(num_samples, dim, device=device)  # x_0 ~ N(0, I)
t_span = torch.linspace(0, 1, 100, device=device)  # 时间步长

# 解 ODE
x_t = odeint(vector_field, x0, t_span, method='rk4')  # [num_steps, batch_size, dim]
x1_pred = x_t[-1]  # 取 t=1 的结果

# 可视化(仅适用于 MNIST)
import matplotlib.pyplot as plt

x1_pred = x1_pred.view(-1, 28, 28).cpu().numpy()  # 重塑为图像
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(x1_pred[i], cmap='gray')
    ax.axis('off')
plt.show()

代码说明

  1. 推理过程

    • 从 ( x 0 ∼ N ( 0 , I ) x_0 \sim \mathcal{N}(0, I) x0N(0,I) ) 开始,使用学到的 ( v θ v_\theta vθ ) 作为向量场。
    • 通过 ODE 求解器(这里用 Runge-Kutta 4 阶方法)积分从 ( t = 0 t = 0 t=0 ) 到 ( t = 1 t = 1 t=1 ),得到 ( x 1 x_1 x1 )。
  2. 流 ( ψ t ( x ) \psi_t(x) ψt(x) ) 的计算

    • 在训练中,( ψ t ( x ) \psi_t(x) ψt(x) ) 用于生成中间状态 ( x t x_t xt )。
    • 在推理中,( ψ t ( x ) \psi_t(x) ψt(x) ) 被 ( v θ v_\theta vθ ) 隐式逼近,ODE 解算器直接输出路径。
  3. 扩展到 ImageNet

    • 输入维度改为 150528,网络可能需要用 CNN(如 U-Net)。
    • 可视化需调整为彩色图像(3x224x224)。

如何为每个 ( x 1 x_1 x1 ) 设计专属路径?

实现原理

  1. 条件化设计

    • 对于每个 ( x 1 ∼ q ( x 1 ) x_1 \sim q(x_1) x1q(x1) )(如 MNIST 的图像),初始噪声 ( x 0 ∼ N ( 0 , I ) x_0 \sim \mathcal{N}(0, I) x0N(0,I) ) 是通用的。
    • OT 路径通过 ( ψ t ( x 0 ) = ( 1 − ( 1 − σ min ) t ) x 0 + t x 1 \psi_t(x_0) = (1 - (1 - \sigma_{\text{min}}) t) x_0 + t x_1 ψt(x0)=(1(1σmin)t)x0+tx1 ) 为每个 ( x 1 x_1 x1 ) 定义一条从 ( x 0 x_0 x0 ) 到 ( x 1 x_1 x1 ) 的专属路径:
      • ( t = 0 t = 0 t=0 ):( x 0 x_0 x0 )(噪声)。
      • ( t = 1 t = 1 t=1 ):接近 ( x 1 x_1 x1 )(目标)。
  2. 向量场的作用

    • ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ) 是解析定义的,依赖于当前 ( x x x ) 和目标 ( x 1 x_1 x1 ),指导 ( x x x ) 沿路径移动。
    • 神经网络 ( v θ v_\theta vθ ) 学习逼近 ( u t u_t ut ),使其适用于任意 ( x 1 x_1 x1 )。
  3. 训练中的实现

    • 每次迭代从数据加载器采样 ( x 1 x_1 x1 ),随机生成 ( t t t ) 和 ( x 0 x_0 x0 )。
    • 计算 ( ψ t ( x 0 ) \psi_t(x_0) ψt(x0) ) 和 ( u t ( ψ t ( x 0 ) ∣ x 1 ) u_t(\psi_t(x_0) | x_1) ut(ψt(x0)x1) ),让 ( v θ v_\theta vθ ) 匹配这条专属路径的方向。
  4. 推理中的实现

    • 从 ( x 0 x_0 x0 ) 开始,( v θ v_\theta vθ ) 驱动 ODE,隐式生成从噪声到数据的路径。
    • 虽然训练时路径针对每个 ( x 1 x_1 x1 ),推理时无需指定 ( x 1 x_1 x1 ),( v θ v_\theta vθ ) 已学到通用映射。

直观理解

  • 专属路径:训练时,( x 1 x_1 x1 ) 作为条件输入,( ψ t \psi_t ψt ) 和 ( u t u_t ut) 明确指向 ( x 1 x_1 x1 ),像为每个目标画一条“导航线”。
  • 逐步变成目标:( σ t \sigma_t σt ) 从 1 减到 ( σ min \sigma_{\text{min}} σmin ),( μ t \mu_t μt ) 从 0 移到 ( x 1 x_1 x1 ),噪声逐渐“聚焦”到 ( x 1 x_1 x1 )。

总结

这段代码展示了 Flow Matching 使用 OT 路径的完整流程:

  • 训练:通过 ( ψ t ( x ) \psi_t(x) ψt(x) ) 和 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ) 为每个 ( x 1 x_1 x1 ) 设计路径,优化 ( v θ v_\theta vθ )。
  • 推理:从噪声解 ODE,直接生成样本。

相比扩散模型,Flow Matching 跳过了前向加噪,直接从噪声到数据的路径更高效。扩展到 ImageNet 只需调整数据和网络规模,核心逻辑不变。希望这篇代码和解释让你明白如何将理论落地!

后记

2025年4月7日22点46分于上海,在grok 3大模型辅助下完成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值