论文提出了一种新的生成模型。论文的目的是给定一个目标分布,有目标分布的一定量的样本,但是不知道目标分布的概率密度函数,学习一个模型能生成服从目标分布的新样本。
Flow Matching (FM)是一种训练连续标准化流Continuous Normalizing Flow (CNF)的方法。
FM是一种通用的方法。FM可以用于训练扩散路径,用FM训练扩散路径更稳定。FM也可以用于训练其他路径,一个例子是训练最优传输(OT)位移插值定义的条件概率路径,这些路径比扩散路径更有效,提供更快的训练和采样,从而获得更好的泛化效果。
核心的思想是把无条件估计问题的转换为有条件的问题的来学习。作者说是从denoised score matching得到的启发:
We first show that we can construct such target vector fields through per-example (i.e., conditional) formulations. Then, inspired by denoising score matching, we show that a per-example training objective, termed Conditional Flow Matching (CFM), provides equivalent gradients and does not require explicit knowledge of the intractable target vector field.
连续标准化流
数据点
x
∈
R
d
\pmb x \in \mathbb R^d
x∈Rd,时变概率密度路径
p
:
[
0
,
1
]
×
R
d
→
R
>
0
p:[0,1] \times \mathbb R^d \rightarrow \mathbb R_{>0}
p:[0,1]×Rd→R>0,时变向量场
v
t
:
[
0
,
1
]
×
R
d
→
R
d
v_t:[0,1] \times \mathbb R^d \rightarrow \mathbb R^d
vt:[0,1]×Rd→Rd。
流flow把一个分布映射成另一个分布,可以通过常微分方程用
v
t
v_t
vt构建flow
ϕ
:
[
0
,
1
]
×
R
d
→
R
d
\phi:[0,1] \times \mathbb R^d \rightarrow \mathbb R^d
ϕ:[0,1]×Rd→Rd:
d
ϕ
t
(
x
)
d
t
=
v
t
(
ϕ
t
(
x
)
)
ϕ
0
(
x
)
=
x
(1)
\frac{d\phi_t(\pmb x)}{dt}=v_t(\phi_t(\pmb x)) \tag{1} \\ \phi_0(\pmb x)=\pmb x
dtdϕt(x)=vt(ϕt(x))ϕ0(x)=x(1)时变向量场可以用神经网络
v
t
(
x
;
θ
)
v_t(\pmb x; \theta)
vt(x;θ)来建模,这样构建的flow
ϕ
t
\phi_t
ϕt叫做连续标准化流(Continuous Normalizing Flow,CNF)。CNF通常用于把一个简单的分布
p
0
p_0
p0变成一个复杂的分布
p
1
p_1
p1,其符合push-forward方程:
p
t
(
x
)
=
[
ϕ
t
]
⋆
p
0
(
x
)
=
p
0
(
ϕ
t
−
1
(
x
)
)
det
[
∂
ϕ
t
−
1
∂
x
(
x
)
]
p_t(x)=[\phi_t]_\star p_0(x)=p_0(\phi_t^{-1}(x))\det[\frac{\partial \phi_t^{-1}}{\partial x}(x)]
pt(x)=[ϕt]⋆p0(x)=p0(ϕt−1(x))det[∂x∂ϕt−1(x)]我们的目标是采样服从复杂目标分布的样本,方法是首先随机采样服从简单分布的噪声样本
x
∼
N
(
0
,
I
)
\pmb x \sim \mathcal N (\pmb 0, \pmb I)
x∼N(0,I),然后使用ODE求解器在区间
t
∈
[
0
,
1
]
t \in [0, 1]
t∈[0,1]上使用训练得到的向量场
v
t
v_t
vt求解方程(1)得到服从目标分布的样本
ϕ
1
(
x
)
\phi_1(\pmb x)
ϕ1(x)。所以主要的问题是如何学习
v
t
(
x
;
θ
)
v_t(\pmb x; \theta)
vt(x;θ)。
Flow Matching(FM)
用
x
1
\pmb x_1
x1表示服从未知的目标分布
q
(
x
1
)
q(\pmb x_1)
q(x1)的随机变量,我们不知道
q
(
x
1
)
q(\pmb x_1)
q(x1)的密度函数,但可以获得服从
q
(
x
1
)
q(\pmb x_1)
q(x1)的样本。用
p
t
p_t
pt表示概率密度路径,
p
0
p_0
p0服从标准高斯分布,
p
1
p_1
p1近似
q
q
q。
Flow Matching的训练目标是学习
v
t
v_t
vt,损失函数是
L
F
M
(
θ
)
=
E
t
,
p
t
(
x
)
∥
v
t
(
x
;
θ
)
−
u
t
(
x
)
∥
2
\mathcal L_{FM}(\theta)=\mathbb E_{t,p_t(\pmb x)}\|v_t(\pmb x; \theta)-u_t(\pmb x)\|^2
LFM(θ)=Et,pt(x)∥vt(x;θ)−ut(x)∥2流匹配的损失函数很简单,但在实践中没法使用,因为我们不知道如何定义合适的
p
t
p_t
pt和
u
t
u_t
ut。
Conditional Flow Matching(CFM)
为了解决上面的问题,考虑条件流匹配。条件流匹配的损失函数是
L
C
F
M
(
θ
)
=
E
t
,
q
(
x
1
)
,
p
t
(
x
∣
x
1
)
∥
v
t
(
x
;
θ
)
−
u
t
(
x
∣
x
1
)
∥
2
\mathcal L_{CFM}(\theta)=\mathbb E_{t,q(\pmb x_1),p_t(\pmb x|\pmb x_1)}\|v_t(\pmb x; \theta)-u_t(\pmb x|\pmb x_1)\|^2
LCFM(θ)=Et,q(x1),pt(x∣x1)∥vt(x;θ)−ut(x∣x1)∥2与流匹配的目标不同,条件流匹配的目标允许我们轻松地对无偏估计进行采样,只要我们可以从
p
t
(
x
∣
x
1
)
p_t(\pmb x|\pmb x_1)
pt(x∣x1) 有效地采样并计算
u
t
(
x
∣
x
1
)
u_t(\pmb x|\pmb x_1)
ut(x∣x1),这两者都可以很容易地完成,因为它们是对每个样本定义的。
论文中证明了优化CFM目标等价于优化FM目标(从期望的角度)。所以,剩下的问题是如何设计合适的条件概率路径
p
t
(
x
∣
x
1
)
p_t(\pmb x|\pmb x_1)
pt(x∣x1)和向量场
u
t
(
x
∣
x
1
)
u_t(\pmb x|\pmb x_1)
ut(x∣x1)。
条件概率路径和条件向量场
上面的讨论是通用的,并没有规定条件概率路径和条件向量场的形式。为了简单,作者讨论的是高斯条件概率路径:
p
t
(
x
∣
x
1
)
=
N
(
x
∣
μ
t
(
x
1
)
,
σ
t
(
x
1
)
2
I
)
p_t(\pmb x|\pmb x_1)=\mathcal N(\pmb x| \mu_t(\pmb x_1), \sigma_t(\pmb x_1)^2\pmb I)
pt(x∣x1)=N(x∣μt(x1),σt(x1)2I)其中
μ
0
(
x
1
)
=
0
\mu_0(\pmb x_1)=0
μ0(x1)=0,
σ
0
(
x
1
)
=
1
\sigma_0(\pmb x_1)=1
σ0(x1)=1,
μ
1
(
x
1
)
=
x
1
\mu_1(\pmb x_1)=\pmb x_1
μ1(x1)=x1,
σ
1
(
x
1
)
=
σ
min
\sigma_1(\pmb x_1)=\sigma_{\min}
σ1(x1)=σmin。
有无数的向量场可以产生给定的概率路径,这里作者讨论的是最简单的典型变换。
考虑条件flow:
ψ
t
(
x
)
=
σ
t
(
x
1
)
x
+
μ
t
(
x
1
)
\psi_t(\pmb x)= \sigma_t(\pmb x_1)\pmb x + \mu_t(\pmb x_1)
ψt(x)=σt(x1)x+μt(x1)对应的条件向量场可以通过求解方程得到,并有封闭解:
u
t
(
x
∣
x
1
)
=
σ
t
′
(
x
1
)
σ
t
(
x
1
)
(
x
−
μ
t
(
x
1
)
)
+
μ
t
′
(
x
1
)
u_t(\pmb x|\pmb x_1)=\frac{\sigma'_t(\pmb x_1)}{\sigma_t(\pmb x_1)}(x-\mu_t(\pmb x_1))+\mu'_t(\pmb x_1)
ut(x∣x1)=σt(x1)σt′(x1)(x−μt(x1))+μt′(x1)优化的损失函数是
L
C
F
M
(
θ
)
=
E
t
,
q
(
x
1
)
,
p
(
x
0
)
∥
v
t
(
ψ
t
(
x
0
)
;
θ
)
−
u
t
(
ψ
t
(
x
0
)
∣
x
1
)
∥
2
\mathcal L_{CFM}(\theta)=\mathbb E_{t,q(\pmb x_1),p(\pmb x_0)}\|v_t(\psi_t(\pmb x_0); \theta)-u_t(\psi_t(\pmb x_0)|\pmb x_1)\|^2
LCFM(θ)=Et,q(x1),p(x0)∥vt(ψt(x0);θ)−ut(ψt(x0)∣x1)∥2