SIMPLIFYING, STABILIZING & SCALING CONTINUOUS-TIME CONSISTENCY MODELS
原文链接https://arxiv.org/abs/2410.11081
简化、稳定和扩展连续时间一致性模型
摘要
一致性模型(CMs)是一类强大的基于扩散的生成模型,它们针对快速采样进行了优化。大多数现有的CMs都是使用离散化的时间步长进行训练的,这引入了额外的超参数,并且容易受到离散化误差的影响。虽然连续时间表述可以减轻这些问题,但由于训练不稳定性,它们的成功受到了限制。为了解决这个问题,我们提出了一个简化的理论框架,它统一了扩散模型和CMs的先前参数化,识别了不稳定性的根本原因。基于这个分析,我们引入了在扩散过程参数化、网络架构和训练目标方面的关键改进。这些变化使我们能够以前所未有的规模训练连续时间CMs,达到了在ImageNet 512×512上的1.5B参数。我们提出的训练算法,仅使用两个采样步骤,就实现了在CIFAR-10上的FID得分为2.06,在ImageNet 64×64上的FID得分为1.48,在ImageNet 512×512上的FID得分为1.88,将与最佳现有扩散模型的FID得分差距缩小到10%以内。
引言
1 引言:
在图1中,我们比较了不同模型在ImageNet 512×512上的质量,通过FID(越低越好)来衡量。我们的两步sCM实现了与最佳先前生成模型相当的样本质量,同时使用的计算量不到10%的有效采样计算。
扩散模型已经彻底改变了生成性AI,在图像、3D、音频和视频生成方面取得了显著的成果。尽管它们取得了成功,但一个显著的缺点是采样速度慢,通常需要几十到几百步才能生成一个样本。已经提出了各种扩散蒸馏技术,包括直接蒸馏、对抗性蒸馏、渐进式蒸馏和变分得分蒸馏。然而,这些方法都带来了挑战:直接蒸馏由于需要大量的扩散模型样本而造成了巨大的计算成本;对抗性蒸馏引入了与GAN训练相关的复杂性;渐进式蒸馏需要多个训练阶段,对于一两步生成效果较差;而VSD可能会产生过于平滑的样本,多样性有限,在高引导水平下表现挣扎。
一致性模型(CMs)在解决这些问题上具有显著优势。它们消除了对扩散模型样本的监督需求,避免了生成合成数据集的计算成本。CMs还绕过了对抗性训练,避开了其固有的困难。除了蒸馏,CMs可以从头开始通过一致性训练(CT)进行训练,不依赖于预训练的扩散模型。先前的工作已经展示了CMs在少步骤生成,尤其是在一两步生成中的有效性。然而,这些结果都是基于离散时间CMs,这引入了离散化误差,并要求仔细安排时间步长网格,可能导致样本质量不是最优。相比之下,连续时间CMs避免了这些问题,但在训练稳定性方面面临挑战。在本文中,我们引入了简化、稳定和扩展连续时间CMs训练的技术。我们的第一个贡献是TrigFlow,这是一种新的表述,它统一了EDM和Flow Matching,显著简化了扩散模型、相关概率流ODE和CMs的表述。在此基础上,我们分析了CM训练中的不稳定性的根本原因,并提出了完整的解决方案。我们的方法包括改进的时间条件和自适应组归一化在网络架构内的应用。此外,我们重新制定了连续时间CMs的训练目标,纳入了关键术语的自适应加权和归一化,以及渐进式退火,以实现稳定和可扩展的训练。
通过这些改进,我们在一致性训练和蒸馏方面提升了一致性模型的性能,达到了与之前离散时间表述相当甚至更好的结果。我们的模型,称为sCMs,展示了在不同数据集和模型尺寸上的成功。我们在CIFAR-10、ImageNet 64×64和ImageNet 512×512上训练sCMs,达到了前所未有的规模,拥有15亿参数——迄今为止训练的最大CMs(见图2中的样本)。我们展示了sCMs随着计算量的增加,以可预测的方式提升样本质量。此外,与需要更多采样计算的最先进的扩散模型相比,sCMs在使用两步生成时将FID差距缩小到10%以内。此外,我们通过展示样本质量随着相邻时间步长间隔的缩小而提高,接近连续时间极限,为连续时间CMs相对于离散时间变体的优势提供了严格的论证。此外,我们还考察了sCMs和VSD之间的差异,发现sCMs产生了更多样化的样本,并且与引导更为兼容,而VSD在更高的引导水平上往往表现不佳。
2.准备工作
2.1扩散模型
给定一个训练数据集,设 p d p_d pd表示其底层数据分布, σ d \sigma_d σd表示其标准差。扩散模型通过学习逆转一个逐步将数据样本 x 0 ∼ p d x_0 \sim p_d x0∼pd扰动为带噪版本 x t = α t x 0 + σ t z x_t = \alpha_t x_0 + \sigma_t z xt=αtx0+σtz的过程来生成样本,其中 z ∼ N ( 0 , I ) z \sim N(0, I) z∼N(0,I)是标准高斯噪声。这个扰动随着 t ∈ [ 0 , T ] t \in [0, T] t∈[0,T]的增加而增加,其中较大的 t t t表示更大的噪声。我们考虑两种最近的扩散模型表述:
EDM (Karras et al., 2022; 2024):
扰动过程简单地设置
α
t
=
1
\alpha_t = 1
αt=1和
σ
t
=
t
\sigma_t = t
σt=t。训练目标由下式给出:
L
Diff
(
θ
)
=
E
x
0
,
z
,
t
[
w
(
t
)
∥
f
θ
(
x
t
,
t
)
−
x
0
∥
2
2
]
,
L_{\text{Diff}}(\theta) = \mathbb{E}_{x_0, z, t} \left[ w(t) \left\| f_\theta(x_t, t) - x_0 \right\|^2_2 \right],
LDiff(θ)=Ex0,z,t[w(t)∥fθ(xt,t)−x0∥22],其中
w
(
t
)
w(t)
w(t)是一个权重函数。
扩散模型参数化为:
f
θ
(
x
t
,
t
)
=
c
skip
(
t
)
x
t
+
c
out
(
t
)
F
θ
(
c
in
(
t
)
x
t
,
c
noise
(
t
)
)
,
f_\theta(x_t, t) = c_{\text{skip}}(t) x_t + c_{\text{out}}(t) F_\theta(c_{\text{in}}(t) x_t, c_{\text{noise}}(t)),
fθ(xt,t)=cskip(t)xt+cout(t)Fθ(cin(t)xt,cnoise(t)),
其中
F
θ
F_\theta
Fθ是具有参数
θ
\theta
θ的神经网络,
c
skip
c_{\text{skip}}
cskip,
c
out
c_{\text{out}}
cout,
c
in
c_{\text{in}}
cin, 和
c
noise
c_{\text{noise}}
cnoise是手动设计的系数,确保训练目标在初始化时在所有时间步上具有单位方差。采样过程通过从
x
T
∼
N
(
0
,
T
2
I
)
x_T \sim N(0, T^2 I)
xT∼N(0,T2I)开始,解决概率流ODE (PF-ODE) 定义为:
d
x
t
d
t
=
x
t
−
f
θ
(
x
t
,
t
)
t
,
\frac{d x_t}{d t} = \frac{x_t - f_\theta(x_t, t)}{t},
dtdxt=txt−fθ(xt,t),从
t
=
T
t = T
t=T到
t
=
0
t = 0
t=0。
Flow Matching:
扰动过程使用可微分系数
α
t
\alpha_t
αt和
σ
t
\sigma_t
σt,其时间导数表示为
α
t
′
\alpha'_t
αt′和
σ
t
′
\sigma'_t
σt′(通常,
α
t
=
1
−
t
\alpha_t = 1 - t
αt=1−t和
σ
t
=
t
\sigma_t = t
σt=t)。 训练目标由下式给出:
L
Diff
(
θ
)
=
E
x
0
,
z
,
t
[
w
(
t
)
∥
F
θ
(
x
t
,
t
)
−
(
α
t
′
x
0
+
σ
t
′
z
)
∥
2
2
]
,
L_{\text{Diff}}(\theta) = \mathbb{E}_{x_0, z, t} \left[ w(t) \left\| F_\theta(x_t, t) - (\alpha'_t x_0 + \sigma'_t z) \right\|^2_2 \right],
LDiff(θ)=Ex0,z,t[w(t)∥Fθ(xt,t)−(αt′x0+σt′z)∥22],
其中
w
(
t
)
w(t)
w(t)是一个权重函数,
F
θ
F_\theta
Fθ是由
θ
\theta
θ参数化的神经网络。 采样过程从
t
=
1
t = 1
t=1开始,
x
1
∼
N
(
0
,
I
)
x_1 \sim N(0, I)
x1∼N(0,I)并解决概率流ODE (PF-ODE),定义为:
d
x
t
d
t
=
F
θ
(
x
t
,
t
)
,
\frac{d x_t}{d t} = F_\theta(x_t, t),
dtdxt=Fθ(xt,t),从
t
=
1
t = 1
t=1到
t
=
0
t = 0
t=0。
2.2一致性模型
一致性模型(CM)是一个神经网络 f θ ( x t , t ) f_\theta(x_t, t) fθ(xt,t),它被训练以直接将带噪输入 x t x_t xt映射到相应的干净数据 x 0 x_0 x0,通过遵循从 x t x_t xt开始的PF-ODE的采样轨迹。一个有效的 f θ f_\theta fθ必须满足边界条件 f θ ( x , 0 ) ≡ x f_\theta(x, 0) \equiv x fθ(x,0)≡x。为了满足这个条件,我们可以参数化一致性模型为 f θ ( x t , t ) = c skip ( t ) x t + c out ( t ) F θ ( c in ( x t ) , c noise ( t ) ) f_\theta(x_t, t) = c_{\text{skip}}(t) x_t + c_{\text{out}}(t) F_\theta(c_{\text{in}}(x_t), c_{\text{noise}}(t)) fθ(xt,t)=cskip(t)xt+cout(t)Fθ(cin(xt),cnoise(t)),其中 c skip ( 0 ) = 1 c_{\text{skip}}(0) = 1 cskip(0)=1和 c out ( 0 ) = 0 c_{\text{out}}(0) = 0 cout(0)=0。
CMs被训练以在相邻时间步上具有一致的输出。根据如何选择接近的时间步,存在两种类型的一致性模型:
离散时间CMs:
训练目标在两个相邻时间步上定义,具有有限距离:
L
CM
(
θ
)
=
E
x
t
,
t
[
w
(
t
)
d
(
f
θ
(
x
t
,
t
)
,
f
θ
(
x
t
−
Δ
t
,
t
−
Δ
t
)
)
]
,
L_{\text{CM}}(\theta) = \mathbb{E}_{x_t, t} \left[ w(t) d(f_\theta(x_t, t), f_\theta(x_{t-\Delta t}, t - \Delta t)) \right],
LCM(θ)=Ext,t[w(t)d(fθ(xt,t),fθ(xt−Δt,t−Δt))], (1)
其中
θ
−
\theta^-
θ−表示
θ
\theta
θ的停止梯度,
w
(
t
)
w(t)
w(t)是权重函数,
Δ
t
>
0
\Delta t > 0
Δt>0是相邻时间步之间的距离,
d
(
⋅
,
⋅
)
d(\cdot, \cdot)
d(⋅,⋅)是度量函数;常见的选择包括
ℓ
2
\ell_2
ℓ2损失
d
(
x
,
y
)
=
∥
x
−
y
∥
2
2
d(x, y) = \| x - y \|_2^2
d(x,y)=∥x−y∥22,伪Huber损失
d
(
x
,
y
)
=
∥
x
−
y
∥
2
2
+
c
2
−
c
d(x, y) = \sqrt{\|x - y\|_2^2 + c^2} - c
d(x,y)=∥x−y∥22+c2−c对于
c
>
0
c > 0
c>0,和LPIPS损失(Zhang等人,2018)。离散时间CMs对
Δ
t
\Delta t
Δt的选择很敏感,因此需要手动设计的退火时间表(Song & Dhariwal, 2023; Geng等人,2024)以实现快速收敛。带噪样本
x
t
−
Δ
t
x_{t-\Delta t}
xt−Δt在前一个时间步
t
−
Δ
t
t - \Delta t
t−Δt通常通过使用步长
Δ
t
\Delta t
Δt的数值ODE求解器从
x
t
x_t
xt获得,这可能导致额外的离散化误差。
连续时间CMs:
当使用
d
(
x
,
y
)
=
∥
x
−
y
∥
2
2
d(x, y) = \| x - y \|_2^2
d(x,y)=∥x−y∥22并取
Δ
t
→
0
\Delta t \to 0
Δt→0的极限时,Song等人(2023)展示了方程(1)相对于
θ
\theta
θ的梯度收敛到:
∇
θ
E
x
t
,
t
[
−
w
(
t
)
f
θ
T
(
x
t
,
t
)
d
f
θ
−
(
x
t
,
t
)
d
t
]
,
\nabla_\theta \mathbb{E}_{x_t, t} \left[ - w(t) f_\theta^T(x_t, t) \frac{d f_\theta^-(x_t, t)}{d t} \right],
∇θExt,t[−w(t)fθT(xt,t)dtdfθ−(xt,t)], (2)
其中
d
f
θ
−
(
x
t
,
t
)
d
t
=
∇
x
t
f
θ
−
(
x
t
,
t
)
d
x
t
d
t
+
∂
f
θ
−
(
x
t
,
t
)
∂
t
\frac{d f_\theta^-(x_t, t)}{d t} = \nabla_{x_t} f_\theta^-(x_t, t) \frac{d x_t}{d t} + \frac{\partial f_\theta^-(x_t, t)}{\partial t}
dtdfθ−(xt,t)=∇xtfθ−(xt,t)dtdxt+∂t∂fθ−(xt,t)是
f
θ
−
f_\theta^-
fθ−在
(
x
t
,
t
)
(x_t, t)
(xt,t)处沿PF-ODE轨迹的切线。值得注意的是,连续时间CMs不依赖于ODE求解器,这避免了离散化误差并提供了更准确的训练监督信号。然而,先前的工作发现,训练连续时间CMs,甚至离散时间CMs与非常小的
Δ
t
\Delta t
Δt,在优化上存在严重的不稳定性。这极大地限制了连续时间CMs的实证性能和采用。
一致性蒸馏和一致性训练:
离散时间CMs和连续时间CMs都可以使用一致性蒸馏(CD)或一致性训练(CT)进行训练。在一致性蒸馏中,通过从预训练的扩散模型中提取知识来训练CM。这个扩散模型提供了可以直接插入方程(2)进行连续时间CMs训练的PF-ODE。此外,通过数值求解PF-ODE从
x
t
x_t
xt获得
x
t
−
Δ
t
x_{t-\Delta t}
xt−Δt,也可以通过方程(1)训练离散时间CMs。相比之下,一致性训练(CT)从头开始训练CMs,无需预训练的扩散模型,这确立了CMs作为独立的生成模型家族。具体来说,CT近似离散时间CMs中的
x
t
−
Δ
t
x_{t-\Delta t}
xt−Δt为
x
t
−
Δ
t
=
α
t
−
Δ
t
x
0
+
σ
t
−
Δ
t
z
x_{t-\Delta t} = \alpha_{t-\Delta t} x_0 + \sigma_{t-\Delta t} z
xt−Δt=αt−Δtx0+σt−Δtz,重用相同的数据
x
0
x_0
x0和噪声
z
z
z来采样
x
t
=
α
t
x
0
+
σ
t
z
x_t = \alpha_t x_0 + \sigma_t z
xt=αtx0+σtz。在连续时间极限中,随着
Δ
t
→
0
\Delta t \to 0
Δt→0,这种方法产生了PF-ODE
d
x
t
d
t
→
α
t
′
x
0
+
σ
t
′
z
\frac{d x_t}{d t} \to \alpha'_t x_0 + \sigma'_t z
dtdxt→αt′x0+σt′z的无偏估计,导致训练连续时间CMs的方程(2)的无偏估计。
为了简化EDM和随后的CMs,我们提出了TrigFlow,这是一种扩散模型的表述,它保持了EDM的属性,但具有更简单的系数关系。TrigFlow是流匹配(也称为随机插值或修正流)和v-预测参数化的特殊情况,并且它结合了两种表述的优势,同时允许扩散过程、扩散模型参数化、PF-ODE、扩散训练目标和CM参数化都具有简单的表达式。
3.简化连续时间一致性模型
TrigFlow扩散过程:
给定
x
0
∼
p
d
(
x
0
)
x_0 \sim p_d(x_0)
x0∼pd(x0)和
z
∼
N
(
0
,
σ
d
2
I
)
z \sim N(0, \sigma_d^2 I)
z∼N(0,σd2I),带噪样本定义为
x
t
=
cos
(
t
)
x
0
+
sin
(
t
)
z
x_t = \cos(t) x_0 + \sin(t) z
xt=cos(t)x0+sin(t)z对于
t
∈
[
0
,
π
2
]
t \in [0, \frac{\pi}{2}]
t∈[0,2π]。作为一个特殊情况,先验样本
x
π
2
∼
N
(
0
,
σ
d
2
I
)
x_{\frac{\pi}{2}} \sim N(0, \sigma_d^2 I)
x2π∼N(0,σd2I)。
TrigFlow扩散模型和PF-ODE:
我们参数化扩散模型为:
f
θ
(
x
t
,
t
)
=
F
θ
(
x
t
σ
d
,
c
noise
(
t
)
)
,
f_\theta(x_t, t) = F_\theta\left(\frac{x_t}{\sigma_d}, c_{\text{noise}}(t)\right),
fθ(xt,t)=Fθ(σdxt,cnoise(t)),
其中
F
θ
F_\theta
Fθ是具有参数
θ
\theta
θ的神经网络,( c_{\text{noise}}(t)$是t的变换以便于时间条件化。相应的PF-ODE由下式给出:
d
x
t
d
t
=
σ
d
F
θ
(
x
t
σ
d
,
c
noise
(
t
)
)
.
\frac{d x_t}{d t} = \sigma_d F_\theta\left(\frac{x_t}{\sigma_d}, c_{\text{noise}}(t)\right).
dtdxt=σdFθ(σdxt,cnoise(t)).(3)
TrigFlow扩散目标:
在TrigFlow中,扩散模型通过最小化以下目标进行训练:
L
Diff
(
θ
)
=
E
x
0
,
z
,
t
[
σ
d
2
∥
F
θ
(
x
t
σ
d
,
c
noise
(
t
)
)
−
v
t
∥
2
2
]
,
L_{\text{Diff}}(\theta) = \mathbb{E}_{x_0, z, t} \left[ \frac{\sigma_d}{2} \left\| F_\theta\left(\frac{x_t}{\sigma_d}, c_{\text{noise}}(t)\right) - v_t \right\|_2^2 \right],
LDiff(θ)=Ex0,z,t[2σd
Fθ(σdxt,cnoise(t))−vt
22],(4)
其中
v
t
=
cos
(
t
)
z
−
sin
(
t
)
x
0
v_t = \cos(t) z - \sin(t) x_0
vt=cos(t)z−sin(t)x0是训练目标。
TrigFlow一致性模型:
如第2.2节中提到的,一个有效的CM必须满足边界条件
f
θ
(
x
,
0
)
≡
x
f_\theta(x, 0) \equiv x
fθ(x,0)≡x。为了执行这个条件,我们参数化CM为PF-ODE的单步解:
f
θ
(
x
t
,
t
)
=
cos
(
t
)
x
t
−
sin
(
t
)
σ
d
F
θ
(
x
t
σ
d
,
c
noise
(
t
)
)
,
f_\theta(x_t, t) = \cos(t) x_t - \sin(t) \sigma_d F_\theta\left(\frac{x_t}{\sigma_d}, c_{\text{noise}}(t)\right),
fθ(xt,t)=cos(t)xt−sin(t)σdFθ(σdxt,cnoise(t)),(5)
其中
c
noise
(
t
)
c_{\text{noise}}(t)
cnoise(t)是时间变换,我们将在第4.1节中讨论。
4.稳定连续时间一致性模型
训练连续时间CMs非常不稳定。因此,与之前的工作中的离散时间CMs相比,它们的表现明显更差。为了解决这个问题,我们建立在TrigFlow框架的基础上,并引入了几个理论上有动机的改进,以稳定连续时间CMs,重点是参数化、网络架构和训练目标。
4.1参数化和网络架构
训练连续时间CMs的关键方程是方程(2),它依赖于切线函数
d
f
θ
−
(
x
t
,
t
)
d
t
\frac{d f_{\theta}^{-}(x_t, t)}{d t}
dtdfθ−(xt,t)。在TrigFlow表述下,这个切线函数由下式给出:
d
f
θ
−
(
x
t
,
t
)
d
t
=
−
cos
(
t
)
(
σ
d
F
θ
−
(
x
t
σ
d
,
t
)
−
d
x
t
d
t
)
−
sin
(
t
)
(
x
t
+
σ
d
d
F
θ
−
d
t
)
.
\frac{d f_{{\theta}^{-}}(x_t, t)}{d t} = -\cos(t) \left( \sigma_d F_{{\theta}^{-}} \left( \frac{x_t}{\sigma_d}, t \right) - \frac{d x_t}{d t} \right) - \sin(t) \left( x_t + \sigma_d \frac{d F_{{\theta}^{-}}}{d t} \right).
dtdfθ−(xt,t)=−cos(t)(σdFθ−(σdxt,t)−dtdxt)−sin(t)(xt+σddtdFθ−).(6)
为了稳定训练,需要确保不同时间步上的切线函数是稳定的。我们发现
σ
d
F
θ
−
\sigma_d F_{{\theta}^{-}}
σdFθ−,PF-ODE
d
x
t
d
t
\frac{d x_t}{d t}
dtdxt,和带噪样本
x
t
x_t
xt 都是相对稳定的。切线函数中唯一剩下的项是
sin
(
t
)
∂
F
θ
−
∂
t
\sin(t) \frac{\partial F_{{\theta}^{-}}}{\partial t}
sin(t)∂t∂Fθ−,它可以分解为:
sin
(
t
)
∂
F
θ
−
∂
t
=
sin
(
t
)
∂
emb
(
c
noise
(
t
)
)
∂
c
noise
(
t
)
∂
F
θ
−
∂
emb
(
c
noise
(
t
)
)
.
\sin(t) \frac{\partial F_{{\theta}^{-}}}{\partial t} = \sin(t) \frac{\partial \text{emb}(c_{\text{noise}}(t))}{\partial c_{\text{noise}}(t)} \frac{\partial F_{{\theta}^{-}}}{\partial \text{emb}(c_{\text{noise}}(t))}.
sin(t)∂t∂Fθ−=sin(t)∂cnoise(t)∂emb(cnoise(t))∂emb(cnoise(t))∂Fθ−.(7)
以下是我们提出的改进措施:
-
恒等时间变换(Identity Time Transformation):
我们提出使用 c noise ( t ) = t c_{\text{noise}}(t) = t cnoise(t)=t 作为默认的时间变换,以减轻数值不稳定性。 -
位置时间嵌入(Positional Time Embeddings):
我们使用位置嵌入,这相当于在傅里叶嵌入中 s ≈ 0.02 s \approx 0.02 s≈0.02,以避免由于傅里叶尺度 s s s 较大引起的不稳定性。 -
自适应双重归一化(Adaptive Double Normalization):
我们提出自适应双重归一化,定义为 y = norm ( x ) ⊙ pnorm ( s ( t ) ) + pnorm ( b ( t ) ) y = \text{norm}(x) \odot \text{pnorm}(s(t)) + \text{pnorm}(b(t)) y=norm(x)⊙pnorm(s(t))+pnorm(b(t)),其中 pnorm ( ⋅ ) \text{pnorm}(\cdot) pnorm(⋅) 表示像素归一化。这种改进有助于稳定CM训练。
这些技术的应用,如在图4中可视化的,展示了我们如何稳定在CIFAR-10上训练的CMs的时间导数。我们发现这些改进有助于稳定CMs的训练动态,而不会损害扩散模型训练(见附录G)。
4.2 训练目标
使用第3节中的TrigFlow表述和第4.1节中提出的技术,连续时间CM训练的梯度在方程(2)中变为:
∇
θ
E
x
t
,
t
[
−
w
(
t
)
σ
d
sin
(
t
)
F
θ
⊤
(
x
t
σ
d
,
t
)
d
f
θ
−
(
x
t
,
t
)
d
t
]
.
\nabla_{\theta} \mathbb{E}_{x_t, t} \left[ - w(t) \sigma_d \sin(t) F_{\theta}^{\top} \left( \frac{x_t}{\sigma_d}, t \right) \frac{d f_{\theta}^{-}(x_t, t)}{d t} \right].
∇θExt,t[−w(t)σdsin(t)Fθ⊤(σdxt,t)dtdfθ−(xt,t)].
我们提出以下额外的技术来显式控制这个梯度,以改进稳定性:
-
切线归一化(Tangent Normalization):
由于CM训练中的梯度方差主要来自切线函数 d f θ − ( x t , t ) d t \frac{d f_{{\theta}^{-}}(x_t, t)}{d t} dtdfθ−(xt,t),我们提出显式归一化切线函数,用 d f θ − ( x t , t ) d t / ( ∥ d f θ − ( x t , t ) d t ∥ + c ) \frac{d f_{{\theta}^{-}}(x_t, t)}{d t} / (\| \frac{d f_{{\theta}^{-}}(x_t, t)}{d t} \| + c) dtdfθ−(xt,t)/(∥dtdfθ−(xt,t)∥+c)替换 d f θ − ( x t , t ) d t \frac{d f_{{\theta}^{-}}(x_t, t)}{d t} dtdfθ−(xt,t),其中我们经验性地设置 c = 0.1 c = 0.1 c=0.1。或者,我们可以将切线限制在 [ − 1 , 1 ] [-1, 1] [−1,1] 内,这也限制了其方差。我们在图5(a)中的结果表明,归一化或截断都显著改进了连续时间CMs的训练。 -
自适应加权(Adaptive Weighting):
先前的工作(Song & Dhariwal, 2023; Geng et al., 2024)为CM训练手动设计权重函数 w ( t ) w(t) w(t),这可能对不同的数据分布和网络架构来说是次优的。遵循EDM2(Karras et al., 2024),我们提出同时训练一个自适应权重函数,这不仅减轻了超参数调整的负担,而且以更好的经验性能超越了手动设计的权重函数,并且训练开销可以忽略不计。我们的方法的关键是观察到 ∇ θ E [ F θ ⊤ y ] = 1 2 ∇ θ E [ ∥ F θ − F θ − + y ∥ 2 2 ] \nabla_{\theta} \mathbb{E} [F_{\theta}^{\top} y] = \frac{1}{2} \nabla_{\theta} \mathbb{E} [\| F_{\theta} - F_{{\theta}^{-}} + y \|_2^2] ∇θE[Fθ⊤y]=21∇θE[∥Fθ−Fθ−+y∥22],其中 y y y是一个与 θ \theta θ无关的任意向量。当使用方程(2)训练连续时间CMs时,我们有 y = − w ( t ) σ d sin ( t ) d f θ − ( x t , t ) d t y = -w(t) \sigma_d \sin(t) \frac{d f_{{\theta}^{-}}(x_t, t)}{d t} y=−w(t)σdsin(t)dtdfθ−(xt,t)。这个观察使我们能够将方程(2)转换为均方误差(MSE)目标的梯度。因此,我们可以使用Karras等人(2024)中的相同方法训练一个自适应权重函数,该函数最小化了时间步上的MSE损失的方差(详细信息见附录D)。在实践中,我们发现整合先验权重 w ( t ) = 1 σ d tan ( t ) w(t) = \frac{1}{\sigma_d} \tan(t) w(t)=σd1tan(t)进一步减少了训练方差。通过整合先验权重,我们通过最小化以下目标训练网络 F θ F_{\theta} Fθ和自适应权重函数 w ϕ ( t ) w_{\phi}(t) wϕ(t):
L sCM ( θ , ϕ ) : = E x t , t [ e D w ϕ ( t ) ∥ F θ ( x t σ d , t ) − F θ − ( x t σ d , t ) − cos ( t ) d f θ − ( x t , t ) d t ∥ 2 2 − w ϕ ( t ) ] , L_{\text{sCM}}(\theta, \phi) :=\mathbb{E}_{x_t, t} \left[ e^{w_{\phi}(t)}_D \left\| F_{\theta} \left( \frac{x_t}{\sigma_d}, t \right) - F_{{\theta}^{-}} \left( \frac{x_t}{\sigma_d}, t \right) - \cos(t) \frac{d f_{{\theta}^{-}}(x_t, t)}{d t} \right\|_2^2 - w_{\phi}(t) \right], LsCM(θ,ϕ):=Ext,t[eDwϕ(t) Fθ(σdxt,t)−Fθ−(σdxt,t)−cos(t)dtdfθ−(xt,t) 22−wϕ(t)],
其中 D D D是 x 0 x_0 x0的维度,我们从对数正态提议分布中采样 tan ( t ) \tan(t) tan(t)(Karras et al., 2022),即 e σ d tan ( t ) ∼ N ( P mean , P std 2 ) e^{\sigma_d \tan(t)} \sim N(P_{\text{mean}}, P^2_{\text{std}}) eσdtan(t)∼N(Pmean,Pstd2)。 -
扩散微调和切线预热(Diffusion Finetuning and Tangent Warmup):
对于一致性蒸馏,我们发现从预训练的扩散模型微调CM可以加速收敛,这与Song等人(2023)和Geng等人(2024)一致。回想在方程(6)中,切线 d f θ − ( x t , t ) d t \frac{d f_{\theta}^{-}(x_t, t)}{d t} dtdfθ−(xt,t)可以分解为两部分:第一项 cos ( t ) ( σ d F θ − − d x t d t ) \cos(t)(\sigma_d F_{{\theta}^{-}} - \frac{d x_t}{d t}) cos(t)(σdFθ−−dtdxt)相对稳定,而第二项 sin ( t ) ( x t + σ d d F θ − d t ) \sin(t)(x_t + \sigma_d \frac{d F_{{\theta}^{-}}}{d t}) sin(t)(xt+σddtdFθ−)引起不稳定性。为了缓解这个问题,我们通过将系数 sin ( t ) \sin(t) sin(t)替换为 r ⋅ sin ( t ) r \cdot \sin(t) r⋅sin(t),其中 r r r在前10k次训练迭代中从0线性增加到1,逐步预热第二项。
通过所有这些技术,无论是离散时间还是连续时间CMs的训练稳定性都得到了显著提高。我们在附录E中提供了离散时间CMs的详细算法,并用相同的设置训练连续时间CMs和离散时间CMs。正如图5©所示,增加离散时间CMs中的离散化步数 N N N通过减少离散化误差提高了样本质量,但一旦 N N N变得太大(超过1024)就会因数值精度问题而退化。相比之下,连续时间CMs在所有 N N N的情况下都显著优于离散时间CMs,这为选择连续时间CMs而不是离散时间对应物提供了强有力的理由。我们将我们的模型称为sCM(s代表简单、稳定和可扩展),并在附录A中提供了sCM训练的详细伪代码。
5.扩展连续时间一致性模型
下面我们通过在各种具有挑战性的数据集上训练大规模scm来测试前几节中提出的所有改进。
5.1 切线向量的计算
在训练大规模扩散模型的常见设置中,包括使用半精度(FP16)和Flash Attention(Dao et al., 2022; Dao, 2023)。由于训练连续时间CMs需要准确计算切线向量 d f θ − ( x t , t ) d t \frac{d f_{\theta}^{-}(x_t, t)}{d t} dtdfθ−(xt,t),我们需要提高数值精度,并支持内存高效的注意力计算,具体如下:
-
JVP重新排列(JVP Rearrangement):
计算切线向量涉及计算 d F θ − d t = ∇ x t F θ − ⋅ d x t d t + ∂ F θ − ∂ t \frac{d F_{{\theta}^{-}}}{d t} = \nabla_{x_t} F_{{\theta}^{-}} \cdot \frac{d x_t}{d t} + \frac{\partial F_{{\theta}^{-}}}{\partial t} dtdFθ−=∇xtFθ−⋅dtdxt+∂t∂Fθ−,这可以通过Fθ−的Jacobian-vector product(JVP)高效获得,输入向量为 ( x t , t ) (x_t, t) (xt,t) 和切线向量 ( d x t d t , 1 ) \left( \frac{d x_t}{d t}, 1 \right) (dtdxt,1)。然而,我们发现当 t t t 接近0或 π 2 \frac{\pi}{2} 2π 时,切线向量在中间层可能会溢出。为了提高数值精度,我们提出重新排列切线向量的计算。具体来说,由于目标方程(8)中的 cos ( t ) d f θ − ( x t , t ) d t \cos(t) \frac{d f_{{\theta}^{-}}(x_t, t)}{d t} cos(t)dtdfθ−(xt,t) 和 d f θ − ( x t , t ) d t \frac{d f_{{\theta}^{-}}(x_t, t)}{d t} dtdfθ−(xt,t) 成正比,我们可以计算JVP为:
cos ( t ) sin ( t ) d F θ − d t = ∇ x t F θ − ( cos ( t ) sin ( t ) d x t d t ) + ∂ F θ − ∂ t ( cos ( t ) sin ( t ) σ d ) . \cos(t) \sin(t) \frac{d F_{{\theta}^{-}}}{d t} = \nabla_{x_t} F_{{\theta}^{-}} \left( \cos(t) \sin(t) \frac{d x_t}{d t} \right) + \frac{\partial F_{{\theta}^{-}}}{\partial t} \left( \cos(t) \sin(t) \sigma_d \right). cos(t)sin(t)dtdFθ−=∇xtFθ−(cos(t)sin(t)dtdxt)+∂t∂Fθ−(cos(t)sin(t)σd). 这种重新排列大大减轻了中间层的溢出问题,从而在FP16中实现了更稳定的训练。 -
Flash Attention的JVP(JVP of Flash Attention):
Flash Attention(Dao et al., 2022; Dao, 2023)在大规模模型训练中被广泛用于注意力计算,提供了GPU内存节省和更快的训练。然而,Flash Attention不计算Jacobian-vector product(JVP)。为了填补这一空白,我们提出了一个类似的算法(详细见附录F),该算法在Flash Attention的风格中高效计算softmax自注意力及其JVP,显著减少了注意力层中JVP计算的GPU内存使用。
通过这些改进,我们成功地在没有训练不稳定性的情况下扩展了连续时间CMs。我们在CIFAR-10、ImageNet 64×64和ImageNet 512×512上训练了各种大小的sCMs,使用EDM2配置(S, M, L, XL, XXL),并评估了在最优引导尺度下的FID,如图6所示。首先,随着模型FLOPs的增加,sCT和sCD都显示出改进的样本质量,表明这两种方法都从扩展中受益。其次,与sCD相比,sCT在较小分辨率下更计算效率高但在较大分辨率下效率较低。第三,sCD对于给定数据集可预测地扩展,保持了跨模型大小的一致相对差异。这表明sCD的FID以与教师扩散模型相同的速率下降,因此sCD与教师扩散模型一样可扩展。随着教师扩散模型的FID随着扩展而下降,sCD和教师模型之间的绝对FID差异也减小。最后,随着采样步骤的增加,相对差异减小,两步sCD的样本质量变得与教师扩散模型相当。
在PDF文件的第5.2节中,讨论了通过在各种具有挑战性的数据集上训练来扩展连续时间一致性模型(CMs)的实验。以下是该节内容的完整翻译:
5.2 实验
为了测试我们在前几节中提出的改进,我们使用一致性训练(称为sCT)和一致性蒸馏(称为sCD)在CIFAR-10、ImageNet 64×64和ImageNet 512×512上训练和扩展连续时间CMs。我们使用FID(Heusel et al., 2017)作为样本质量的基准。我们遵循Score SDE(Song et al., 2021b)在CIFAR-10上的设置,以及EDM2(Karras et al., 2024)在ImageNet 64×64和ImageNet 512×512上的设置,同时根据第4.1节中的讨论改变参数化和架构。我们采用了Song et al.(2023)提出的方法进行sCT和sCD的两步采样,使用固定的中间时间步 t = 1.1 t = 1.1 t=1.1。对于ImageNet 512×512上的sCD模型,由于教师扩散模型依赖于分类器自由引导(CFG)(Ho & Salimans, 2021),我们在模型 F θ F_{\theta} Fθ中增加了额外的输入 s s s来表示引导尺度(Meng et al., 2023)。我们通过均匀采样 s ∈ [ 1 , 2 ] s \in [1, 2] s∈[1,2]并应用相应的CFG到教师模型中进行sCD训练(更多细节见附录G)。对于sCT模型,我们没有测试CFG,因为它与一致性训练不兼容。
sCM的训练计算。我们使用与教师扩散模型相同的批量大小进行所有数据集的训练。sCD的有效训练计算每次迭代大约是教师模型的两倍。我们观察到,sCD的两步样本质量迅速收敛,使用不到教师训练计算的20%就实现了与教师扩散模型相当的结果。在实践中,我们可以使用sCD在仅20k次微调迭代后获得高质量的样本。
基准测试。在表1和表2中,我们通过基准测试FIDs和函数评估(NFEs)的数量,将我们的结果与先前的方法进行比较。首先,sCM在不依赖联合训练的其他少步骤方法中表现最佳,并且与使用对抗性训练的最佳结果相当或更好。值得注意的是,sCD-XXL在ImageNet 512×512上的1步FID超过了StyleGAN-XL(Sauer et al., 2022)和VAR(Tian et al., 2024a)。此外,sCD-XXL的两步FID优于除扩散之外的所有生成模型,并且与需要63个连续步骤的最佳扩散模型相当。其次,两步sCM显著缩小了与教师模型的FID差距,到10%以内,实现了CIFAR-10上的FID为2.06(教师FID为2.01),ImageNet 64×64上的FID为1.48(教师FID为1.33),以及ImageNet 512×512上的FID为1.88(教师FID为1.73)。此外,我们观察到sCT在较小规模时更有效,但在较大规模时方差增加,而sCD在小规模和大规模上都表现出一致的性能。
扩展研究。基于我们改进的训练技术,我们成功地扩展了连续时间CMs而没有训练不稳定性。我们在ImageNet 64×64和512×512上使用EDM2配置(S, M, L, XL, XXL)训练了各种大小的sCMs,并在最优引导尺度下评估了FID,如图6所示。首先,随着模型FLOPs的增加,sCT和sCD都显示出改进的样本质量,表明这两种方法都从扩展中受益。其次,与sCD相比,sCT在较小分辨率下更计算效率高但在较大分辨率下效率较低。第三,sCD对于给定数据集可预测地扩展,保持了跨模型大小的一致相对差异。这表明sCD的FID以与教师扩散模型相同的速率下降,因此sCD与教师扩散模型一样可扩展。随着教师扩散模型的FID随着扩展而下降,sCD和教师模型之间的绝对FID差异也减小。最后,随着采样步骤的增加,相对差异减小,两步sCD的样本质量变得与教师扩散模型相当。
与VSD的比较。变分得分蒸馏(VSD)(Wang et al., 2024; Yin et al., 2024b)及其多步概括(Xie et al., 2024b; Salimans et al., 2024)是另一种已经证明在高分辨率图像上可扩展的扩散蒸馏技术。我们应用从时间T到0的一步VSD来微调教师扩散模型,使用EDM2-M配置,并调整了权重函数和提议分布以进行公平比较。如图7所示,我们通过在变化的引导尺度上进行扫描,比较了sCD、VSD、VSD和sCD的组合(通过简单地添加两个损失)以及教师扩散模型。我们观察到VSD具有类似于在扩散模型中应用大引导尺度的伪影:它增加了保真度(由更高的精度分数证明),同时降低了多样性(由更低的召回分数证明)。这种效果随着增加的引导尺度变得更加明显,最终导致严重的模式崩溃。相比之下,两步sCD的精度和召回分数与教师扩散模型相当,从而实现了比VSD更好的FID分数。
在PDF文件的第6节中,总结了研究成果并讨论了其意义。以下是该节内容的完整翻译:
6 结论
我们提出的改进公式、架构和训练目标简化并稳定了连续时间一致性模型的训练,使我们能够将模型扩展到在ImageNet 512×512上的15亿参数。我们对TrigFlow公式、切线归一化和自适应加权的影响进行了消融研究,确认了它们的有效性。结合这些改进,我们的方法在不同数据集和模型大小上展示了可预测的可扩展性,超过了大规模情况下其他少步骤采样方法的性能。值得注意的是,我们在使用两步生成的情况下,将与教师模型的FID差距缩小到10%以内,与需要显著更多采样步骤的最先进的扩散模型相比。