Flow-matching |Stable diffusion3 ,Flux.1的理论基础
前置知识:Flow NF
在AIGC领域,Flow、Normalizing Flow和Continuous Normalizing Flow均是重要的生成模型方法。
Flow
Flow流模型旨在通过学习复杂的概率分布来生成数据。其核心思想是通过一系列可逆变换将复杂分布映射到简单分布。这些变换的可逆性使得Flow模型在生成新样本和计算数据的对数似然时都非常高效。
这里的每一个G就是一个可逆变换,理论上只要我们可以精确的学习到每一个G,那么就能准确的估计出真实的分布,从而完成生成任务
Normalizing Flow
Normalizing Flow(规范化流)是一种具体的Flow模型,它通过一系列可逆且可微的变换将复杂数据分布映射到标准正态分布。其主要特点包括:
- 可逆性:每一步变换都是可逆的,这意味着可以从简单分布(如高斯分布)生成样本,并通过逆变换得到复杂分布的样本。
- 变化的可微性:每个变换都是可微的,这使得可以使用梯度下降等优化方法来训练模型。
- 链式法则:通过链式法则,可以有效地计算变换后的数据的对数似然。
Normalizing Flow的常见变种包括RealNVP和Glow等,它们在图像生成任务中表现优异。
CNF 和 Flow-matching
Continuous Normalizing Flow
Continuous Normalizing Flow(连续规范化流,CNF)是Normalizing Flow的连续版本。其核心思想是使用连续时间的微分方程来描述数据分布的变换。具体来说,CNF通过解常微分方程(ODE)来实现连续的概率密度变换。
为了描述CNF,我们首先要定义三个概念:
1)probability density path
p
:
[
0
,
1
]
×
R
d
→
R
>
0
p:[0,1] \times R^d \rightarrow R_{>0}
p:[0,1]×Rd→R>0,由于我们已经对时间t进行归一化处理,所以
t
∈
[
0
,
1
]
t \in [0,1]
t∈[0,1],每一个数据点为
R
d
R^d
Rd,所以概率密度路径就是由无穷多个概率密度函数组成的,在t给定的时候
p
t
:
R
d
→
R
>
0
p_t:R^d \rightarrow R_{>0}
pt:Rd→R>0就是一个概率分布。
2)time-dependent vector field
v
:
[
0
,
1
]
×
R
d
→
R
d
v : [0, 1] \times \mathbb{R}^d \rightarrow \mathbb{R}^d
v:[0,1]×Rd→Rd,
v
t
v_t
vt 就类似于上图中的G,通过对来自上一个分布的采样数据点x进行映射,形成新的概率分布
3)flow
ϕ
t
\phi_t
ϕt,给定t时,是一个流,用于将
p
0
p_0
p0直接映射到
p
t
p_t
pt
p
t
=
[
ϕ
t
]
∗
p
0
p
t
(
x
)
=
[
ϕ
t
]
∗
p
0
(
x
)
=
p
0
(
ϕ
t
−
1
(
x
)
)
det
[
∂
ϕ
t
−
1
∂
x
]
p_t = [\phi_t]_* p_0\\ p_t(x)=[\phi_t]_* p_0(x) = p_0(\phi_t^{-1}(x)) \det \left[ \frac{\partial \phi_t^{-1}}{\partial x} \right]
pt=[ϕt]∗p0pt(x)=[ϕt]∗p0(x)=p0(ϕt−1(x))det[∂x∂ϕt−1]
显然,向量场和流的关系是:
d
d
t
ϕ
t
(
x
)
=
v
t
(
ϕ
t
(
x
)
)
ϕ
0
(
x
)
=
x
\frac{d}{dt} \phi_t(x) = v_t(\phi_t(x))\\ \phi_0(x) = x
dtdϕt(x)=vt(ϕt(x))ϕ0(x)=x
Flow Matching
简单来说FM是一种通过拟合条件概率路径来高效训练CNF的框架和方法
我们令真实的分布为
q
(
x
)
q(x)
q(x),我们设定
p
0
(
x
)
p_0(x)
p0(x)为一个简单的分布,比如像高斯分布,设定
p
1
(
x
)
p_1(x)
p1(x)是一个和
q
(
x
)
q(x)
q(x)近乎等价的分布,我们构造一个这样的概率密度路径,并找到其对应的向量场
u
t
(
x
)
u_t(x)
ut(x),我们就可以得到损失函数
L
FM
(
θ
)
=
E
t
,
p
t
(
x
)
∥
v
t
(
x
)
−
u
t
(
x
)
∥
2
\mathcal{L}_{\text{FM}}(\theta) = \mathbb{E}_{t, p_t(x)} \|v_t(x) - u_t(x)\|^2
LFM(θ)=Et,pt(x)∥vt(x)−ut(x)∥2
这里的
v
t
(
x
)
v_t(x)
vt(x)就是我们的模型输出
但是在实际中这很难被利用,因为符合条件的概率密度路径有很多,我们很难找到一个合适的,容易被计算机模拟和求解的,最重要的是,我们基本上得不到对应于 p t p_t pt的 u t u_t ut的闭式解
于是作者进行了改进,不去直接讨论 p t p_t pt和 u t u_t ut,而是针对每一个特定的样本,去讨论基于给定样本的条件概率路径 p t ( x ∣ x 1 ) p_t(x|x_1) pt(x∣x1)和条件向量场 u t ( x ∣ x 1 ) u_t(x|x_1) ut(x∣x1),以这个为中介,最终得到与 u t u_t ut拟合的很好的 v t v_t vt
通过条件概率路径和向量场来构建目标概率路径以及相应的优化目标
作者提出了一种通过混合条件概率路径来构建目标概率路径的方法:
-
条件概率路径 p t ( x ∣ x 1 ) p_t(x|x_1) pt(x∣x1):
- 给定一个数据样本 x 1 x_1 x1,定义条件概率路径 p t ( x ∣ x 1 ) p_t(x|x_1) pt(x∣x1)。
- 在时间 t = 0 t=0 t=0 时, p 0 ( x ∣ x 1 ) = p ( x ) p_0(x|x_1) = p(x) p0(x∣x1)=p(x),与 x 1 x_1 x1无关,使得 p 0 ( x ) p_0(x) p0(x)为一个独立于 q ( x ) q(x) q(x)的分布
- 在时间 t = 1 t=1 t=1 时, p 1 ( x ∣ x 1 ) p_1(x|x_1) p1(x∣x1) 是一个集中在 x = x 1 x=x_1 x=x1 附近的分布,使得 p 1 ( x ) p_1(x) p1(x)很好的近似 q ( x ) q(x) q(x)
-
边际概率路径 p t ( x ) p_t(x) pt(x):
-
通过对条件概率路径进行边际化得到边际概率路径:
p t ( x ) = ∫ p t ( x ∣ x 1 ) q ( x 1 ) d x 1 p_t(x) = \int p_t(x|x_1) q(x_1) dx_1 pt(x)=∫pt(x∣x1)q(x1)dx1 -
构造合适的 p 1 ( x ∣ x 1 ) p_1(x|x_1) p1(x∣x1),使得在时间 t = 1 t=1 t=1 时,边际概率 p 1 ( x ) p_1(x) p1(x) 可以很好地近似数据分布 q q q:
p 1 ( x ) = ∫ p 1 ( x ∣ x 1 ) q ( x 1 ) d x 1 ≈ q ( x ) p_1(x) = \int p_1(x|x_1) q(x_1) dx_1 \approx q(x) p1(x)=∫p1(x∣x1)q(x1)dx1≈q(x)
-
-
边际向量场 u t ( x ) u_t(x) ut(x):
-
通过对条件向量场进行边际化定义边际向量场:
u t ( x ) = ∫ u t ( x ∣ x 1 ) p t ( x ∣ x 1 ) q ( x 1 ) p t ( x ) d x 1 u_t(x) = \int \frac{u_t(x|x_1) p_t(x|x_1) q(x_1)}{p_t(x)} dx_1 ut(x)=∫pt(x)ut(x∣x1)pt(x∣x1)q(x1)dx1 -
条件向量场 u t ( ⋅ ∣ x 1 ) u_t(\cdot|x_1) ut(⋅∣x1) 生成条件概率路径 p t ( ⋅ ∣ x 1 ) p_t(\cdot|x_1) pt(⋅∣x1)
-
-
定理 1:
- 给定生成条件概率路径 p t ( x ∣ x 1 ) p_t(x|x_1) pt(x∣x1) 的向量场 u t ( x ∣ x 1 ) u_t(x|x_1) ut(x∣x1),对于任意分布 q ( x 1 ) q(x_1) q(x1),定义的边际向量场 u t u_t ut 会生成定义的边际概率路径 p t p_t pt。
这就意味着,我们可以用条件向量场聚合成真正的边际向量场,并得到边际概率密度路径。将问题转化为拟合简单的条件向量场。
但是即便我们得到了条件向量场,由于(6)式子定义中包含了
q
(
x
1
)
q(x_1)
q(x1),所以还是不能得到
u
t
(
x
)
u_t(x)
ut(x),于是作者定义了一个更简单的目标函数
L
CFM
(
θ
)
=
E
t
,
q
(
x
1
)
,
p
t
(
x
∣
x
1
)
[
∥
v
t
(
x
)
−
u
t
(
x
∣
x
1
)
∥
2
]
L_{\text{CFM}}(\theta) = \mathbb{E}_{t, q(x_1), p_t(x|x_1)} \left[ \left\| v_t(x) - u_t(x|x_1) \right\|^2 \right]
LCFM(θ)=Et,q(x1),pt(x∣x1)[∥vt(x)−ut(x∣x1)∥2]
这个目标可以通过有效采样
p
t
(
x
∣
x
1
)
p_t(x|x_1)
pt(x∣x1) 和计算
u
t
(
x
∣
x
1
)
u_t(x|x_1)
ut(x∣x1) 来实现,当给定样本时,这两项的获得都会变得很容易,后文中会继续探讨。作者通过证明定理2,说明了优化(7)就等价于优化(3),这个论证是十分关键的
定理 2: 假设 p t ( x ) > 0 p_t(x) > 0 pt(x)>0 对所有 x ∈ R d x \in \mathbb{R}^d x∈Rd 和 t ∈ [ 0 , 1 ] t \in [0, 1] t∈[0,1] 成立,则 L CFM L_{\text{CFM}} LCFM 和 L FM L_{\text{FM}} LFM 除了一个与参数 θ \theta θ 无关的常数外是相等的,即 ∇ θ L FM ( θ ) = ∇ θ L CFM ( θ ) \nabla_\theta L_{\text{FM}}(\theta) = \nabla_\theta L_{\text{CFM}}(\theta) ∇θLFM(θ)=∇θLCFM(θ)。
自此,我们需要找这样一个条件概率路径(假设 x 1 x_1 x1是给定的一个样本点)
- p 0 ( x ∣ x 1 ) p_0(x|x_1) p0(x∣x1)与 x 1 x_1 x1无关,且接近于一个简单分布
- p 1 ( x ∣ x 1 ) p_1(x|x_1) p1(x∣x1) 是一个集中在 x = x 1 x=x_1 x=x1 附近的分布
- 这个条件概率路径对应的条件向量场是可求的
高斯条件概率路径
高斯条件概率路径就是一个符合上述要求的路径:
p
t
(
x
∣
x
1
)
=
N
(
x
∣
μ
t
(
x
1
)
,
σ
t
(
x
1
)
2
I
)
p_t(x|x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I)
pt(x∣x1)=N(x∣μt(x1),σt(x1)2I)
其中
μ
0
(
x
1
)
=
0
\mu_0(x_1) = 0
μ0(x1)=0
σ
0
(
x
1
)
=
1
\sigma_0(x_1) = 1
σ0(x1)=1 且
μ
1
(
x
1
)
=
x
1
\mu_1(x_1) = x_1
μ1(x1)=x1
σ
1
(
x
1
)
=
σ
min
\sigma_1(x_1) = \sigma_{\min}
σ1(x1)=σmin
对应的一个流就是
ϕ
t
(
x
)
=
σ
t
(
x
1
)
x
+
μ
t
(
x
1
)
\phi_t(x) = \sigma_t(x_1)x + \mu_t(x_1)
ϕt(x)=σt(x1)x+μt(x1)
其中x服从标准高斯分布
根据公式
d
d
t
ψ
t
(
x
)
=
u
t
(
ψ
t
(
x
)
∣
x
1
)
\frac{d}{dt} \psi_t(x) = u_t(\psi_t(x)|x_1)
dtdψt(x)=ut(ψt(x)∣x1),可以求得
u
t
(
x
∣
x
1
)
=
σ
t
′
(
x
1
)
σ
t
(
x
1
)
(
x
−
μ
t
(
x
1
)
)
+
μ
t
′
(
x
1
)
u_t(x|x_1) = \frac{\sigma'_t(x_1)}{\sigma_t(x_1)} (x - \mu_t(x_1)) + \mu'_t(x_1)
ut(x∣x1)=σt(x1)σt′(x1)(x−μt(x1))+μt′(x1)
可以看出,这个条件向量场是可求的,那么我们就能通过优化公式(7)来达到训练目标
Diffusion
在这一小节,作者用flow-matching的思想统一了之前的score-matching和 DDPM
score-matching
原始score-matching的做法是给原始数据上加不同程度的噪声,采样阶段按照噪声强度从大到小的顺序进行采样,不同程度的噪声对应着不同的概率分布,这些概率分布共同构成了概率密度路径,将这些概率分布连续化就得到了
p
t
(
x
)
=
N
(
x
∣
x
1
,
σ
1
−
t
2
I
)
w
h
e
r
e
σ
t
i
s
a
n
i
n
c
r
e
a
s
i
n
g
f
u
n
c
t
i
o
n
,
σ
0
=
0
a
n
d
σ
1
>
>
1
p_t(x) = \mathcal{N}(x|x_1, \sigma_{1-t}^2 I)\\ where \ \sigma_t\ is\ an\ increasing\ function\ ,\sigma_0=0\ and\ \sigma_1>>1
pt(x)=N(x∣x1,σ1−t2I)where σt is an increasing function ,σ0=0 and σ1>>1
所以
μ
t
(
x
1
)
=
x
1
,
σ
t
(
x
1
)
=
σ
1
−
t
\mu_t(x_1)=x_1,\sigma_t(x_1)=\sigma_{1-t}
μt(x1)=x1,σt(x1)=σ1−t,与之对应的
u
t
(
x
∣
x
1
)
=
−
σ
1
−
t
′
σ
1
−
t
(
x
−
x
1
)
u_t(x|x_1) = -\frac{\sigma'_{1-t}}{\sigma_{1-t}}(x - x_1)
ut(x∣x1)=−σ1−tσ1−t′(x−x1)
DDPM
原始DDPM的加噪公式为
q
(
x
t
∣
x
0
)
=
N
(
x
t
;
α
ˉ
t
x
0
,
(
1
−
α
ˉ
t
)
I
)
q(x_t | x_0) = \mathcal{N} \left( x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) I \right)
q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I),将其连续化后得到如下公式
p
t
(
x
∣
x
1
)
=
N
(
x
∣
α
1
−
t
x
1
,
(
1
−
α
1
−
t
2
)
I
)
w
h
e
r
e
α
t
=
e
−
1
2
T
(
t
)
,
T
(
t
)
=
∫
0
t
β
(
s
)
d
s
p_t(x|x_1) = \mathcal{N}(x|\alpha_{1-t} x_1, (1 - \alpha_{1-t}^2) I)\\ where\ \alpha_t = e^{-\frac{1}{2} T(t)}, T(t) = \int_0^t \beta(s) ds
pt(x∣x1)=N(x∣α1−tx1,(1−α1−t2)I)where αt=e−21T(t),T(t)=∫0tβ(s)ds
所以
μ
t
(
x
1
)
=
α
1
−
t
x
1
a
n
d
σ
t
(
x
1
)
=
1
−
α
1
−
t
2
\mu_t(x_1) = \alpha_{1-t} x_1\ and \ \sigma_t(x_1) = \sqrt{1 - \alpha_{1-t}^2}
μt(x1)=α1−tx1 and σt(x1)=1−α1−t2,与之对应的条件向量场即为
u
t
(
x
∣
x
1
)
=
α
1
−
t
′
1
−
α
1
−
t
2
(
α
1
−
t
x
−
x
1
)
=
−
T
′
(
1
−
t
)
2
[
e
−
T
(
1
−
t
)
x
−
e
−
1
2
T
(
1
−
t
)
x
1
1
−
e
−
T
(
1
−
t
)
]
u_t(x|x_1) = \frac{\alpha'_{1-t}}{1 - \alpha_{1-t}^2} (\alpha_{1-t} x - x_1) = -\frac{T'(1-t)}{2} \left[ \frac{e^{-T(1-t)} x - e^{-\frac{1}{2} T(1-t)} x_1}{1 - e^{-T(1-t)}} \right]
ut(x∣x1)=1−α1−t2α1−t′(α1−tx−x1)=−2T′(1−t)[1−e−T(1−t)e−T(1−t)x−e−21T(1−t)x1]
用Flow Matching的重述DDPM和score-matching,不仅可以提供一个全新的视角,更重要的是也会使得模型训练变得更稳定
Optimal Transport
既然符合条件的概率密度路径有无穷多个,那么对于其最自然的选择就是使均值和方差随时间线性变化
μ
t
(
x
)
=
t
x
1
,
σ
t
(
x
)
=
1
−
(
1
−
σ
min
)
t
\mu_t(x) = t x_1, \quad \sigma_t(x) = 1 - (1 - \sigma_{\min}) t
μt(x)=tx1,σt(x)=1−(1−σmin)t
其对应的条件向量场和流分别是
u
t
(
x
∣
x
1
)
=
x
1
−
(
1
−
σ
min
)
x
1
−
(
1
−
σ
min
)
t
u_t(x|x_1) = \frac{x_1 - (1 - \sigma_{\min}) x}{1 - (1 - \sigma_{\min}) t}
ut(x∣x1)=1−(1−σmin)tx1−(1−σmin)x
ψ t ( x ) = ( 1 − ( 1 − σ min ) t ) x + t x 1 \psi_t(x) = (1 - (1 - \sigma_{\min}) t) x + t x_1 ψt(x)=(1−(1−σmin)t)x+tx1
事实上,这样的条件流实际上是两个高斯分布 p 0 ( x ∣ x 1 ) p_0(x∣x1) p0(x∣x1)和 p 1 ( x ∣ x 1 ) p_1(x∣x1) p1(x∣x1)之间的最优传输的位移映射(Optimal Transport displacement map)
目标函数此时就成为
L
CFM
(
θ
)
=
E
t
,
q
(
x
1
)
,
p
(
x
0
)
∥
v
t
(
ψ
t
(
x
0
)
)
−
(
x
1
−
(
1
−
σ
min
)
x
0
)
∥
2
L_{\text{CFM}} (\theta) = \mathbb{E}_{t, q(x_1), p(x_0)} \left\| v_t(\psi_t(x_0)) - \left( {x_1 - (1 - \sigma_{\min}) x_0} \right) \right\|^2
LCFM(θ)=Et,q(x1),p(x0)∥vt(ψt(x0))−(x1−(1−σmin)x0)∥2
最开始看到这个公式很震惊很疑惑🤔,因为当
σ
m
i
n
\sigma_{min}
σmin最够小的时候,拟合目标等价为
x
1
−
x
0
x_1-x_0
x1−x0,这不就是噪声
ϵ
\epsilon
ϵ吗,兜兜转转又回到了最初的DDPM?细想下来还是有区别的,DDPM中拟合的是真实图片分布和众多真实分布加噪声之后形成的概率分布之间的噪声,Optimal Transport似乎暴力,只去考虑了真实图片分布和标准高斯分布之间的噪声。但是作者的实验最终证明了这反倒可以使模型获得更好的性能。
模型训练完成之后,就可以从一个标准高斯分布中采样,然后一步步调用 v t ( x ) v_t(x) vt(x)来生成图片,显然,当条件样本 x 1 x_1 x1固定时,OT的向量场是直的,DDPM的向量场是弯曲的,这也从侧面应证了在DDPM中强行跳步会导致图片质量下降严重
OT的设计方向是,在给定样本点时,向量场是直的,但是在模型训练中,不可能只针对单个点进行拟合,那么众多样本点的条件向量场肯定会有重叠,模型在重叠部分中的拟合结果会变得杂乱无章,所以最终采样时走的路径也不是标准的直线,那么能否有办法,在条件向量场是直的的情况下,使得最终的 v t ( x ) v_t(x) vt(x)也能始终保持是直的呢
Rectified Flow
Rectified Flow是Flow Matching基础上的一种改进,Flow Matching是设计点源场为直线,而Rectified Flow则是让最后的叠加场为直线,从而使最终采样时走直线,因此最后形式上更加简单,也因为采样是走直线,所以可以一步生成
为了说明rectified flow的原理,我们考虑最简单的情况,假设我们的学习任务是将蓝色表示的概率分布映射到红色表示的概率分布,那么如果用只用OT进行训练,那么在中间地带的向量场就会紊乱(左图),rectified flow的工作就是通过重新配对训练样本点的方式,在原训练的基础上再次训练,获得右图的向量场
如果想进一步探究Rectified Flow的技术细节的话,可以查看RF的官方推文