扩散模型设计选项的全面拆解与分析:EDM
EDM 是扩散模型形式化研究中的一篇重要工作。在当时,DDPM、SMLD、SDE 等扩散模型工作陆续被提出,它们的建模角度不尽相同,却又有着千丝万缕的联系。EDM 指出现有的扩散模型设计过于杂糅,并构建了一个清晰的扩散模型设计空间,其中各种具体的设计选项解耦开来,从而能够更精确地分析不同设计选项对于采样、训练各过程的实际影响。
统一的框架
我们首先基于 Song 等人提出的 SDE 形式,对已有的扩散模型理论进行整理,构建一个统一的理论框架,其中各种扩散模型设计中的各种选项相互解耦开来,便于我们进一步通过推导或实验确定通用的最佳选项。
之前工作中的 ODE 形式
记数据分布为 p data ( x ) p_\text{data}(\mathbf{x}) pdata(x),其方差为 σ data 2 \sigma_\text{data}^2 σdata2。通过向数据添加方差为 σ 2 \sigma^2 σ2 的不同强度的高斯噪声,可以得到一系列加噪分布 p ( x ; σ ) p(\mathbf{x};\sigma) p(x;σ)。当噪声的 sigma 足够大, σ max > > σ data \sigma_\text{max}>>\sigma_\text{data} σmax>>σdata ,此时 p ( x ; σ max ) p(\mathbf{x};\sigma_\text{max}) p(x;σmax) 中就几乎看不出数据,成为完全的的噪声分布了。扩散模型的想法就是,随机采样一个起始纯噪声图 x ∼ N ( 0 , σ max 2 I ) \mathbf{x}\sim\mathcal{N}(0,\sigma_\text{max}^2\mathbf{I}) x∼N(0,σmax2I),并逐步对其进行去噪,每一步得到噪声稍小的 x i ∼ p ( x i ; σ i ) \mathbf{x}_i\sim p(\mathbf{x}_i;\sigma_i) xi∼p(xi;σi),最终就能得到服从 p data p_\text{data} pdata 的干净数据了。
在 Song 等人的工作中,扩散模型 SDE 定义为如下形式:
d
x
=
f
(
x
,
t
)
d
t
+
g
(
t
)
d
w
,
forward
d
x
=
[
f
(
x
,
t
)
−
g
(
t
)
∇
x
log
p
(
x
)
]
d
t
+
g
(
t
)
d
w
~
,
backward
(1)
d\mathbf{x}=f(\mathbf{x},t)dt+g(t)d\mathbf{w},\quad \text{forward} \\ d\mathbf{x}=[f(\mathbf{x},t)-g(t)\nabla_\mathbf{x}\log p(\mathbf{x})]dt+g(t)d\tilde{\mathbf{w}},\quad \text{backward} \tag{1}
dx=f(x,t)dt+g(t)dw,forwarddx=[f(x,t)−g(t)∇xlogp(x)]dt+g(t)dw~,backward(1)
一般来说,
f
(
⋅
)
f(\cdot)
f(⋅) 取关于
x
\mathbf{x}
x 的线性形式,即
f
(
x
,
t
)
=
f
(
t
)
x
f(\mathbf{x},t)=f(t)\mathbf{x}
f(x,t)=f(t)x,上式的前向过程可写为
d
x
=
f
(
t
)
x
d
t
+
g
(
t
)
d
w
d\mathbf{x}=f(t)\mathbf{x}dt+g(t)d\mathbf{w} \notag
dx=f(t)xdt+g(t)dw。
这个 SDE 对应了一个具有完全相同边缘分布
p
t
(
x
)
p_t(\mathbf{x})
pt(x) 的概率流 ODE 形式:
d
x
=
[
f
(
t
)
x
−
1
2
∇
x
log
p
(
x
)
]
d
t
(2)
d\mathbf{x}=\left[f(t)\mathbf{x}-\frac{1}{2}\nabla_\mathbf{x}\log p(\mathbf{x})\right]dt \tag{2}
dx=[f(t)x−21∇xlogp(x)]dt(2)
Song 等人还给出了从时刻
0
0
0 到
t
t
t 加噪过程中扰动核的一般形式:
p
0
t
(
x
t
∣
x
0
)
=
N
(
x
t
;
s
(
t
)
x
0
,
s
2
(
t
)
σ
2
(
t
)
I
)
i
.
e
.
x
t
=
s
(
t
)
x
0
+
s
(
t
)
σ
(
t
)
ϵ
其中
s
(
t
)
=
exp
(
∫
0
t
f
(
ξ
)
d
ξ
)
,
σ
(
t
)
=
∫
0
t
g
2
(
ξ
)
s
2
(
ξ
)
d
ξ
(3)
p_{0t}(\mathbf{x}_t|\mathbf{x}_0)=\mathcal{N}(\mathbf{x}_t;s(t)\mathbf{x}_0,s^2(t)\sigma^2(t)\mathbf{I}) \\ i.e.\quad \mathbf{x}_t=s(t)\mathbf{x}_0+s(t)\sigma(t)\epsilon \\ 其中\quad s(t)=\exp\left(\int_0^tf(\xi)d\xi\right),\quad\sigma(t)=\sqrt{\int_0^t\frac{g^2(\xi)}{s^2(\xi)}}d\xi \tag{3}
p0t(xt∣x0)=N(xt;s(t)x0,s2(t)σ2(t)I)i.e.xt=s(t)x0+s(t)σ(t)ϵ其中s(t)=exp(∫0tf(ξ)dξ),σ(t)=∫0ts2(ξ)g2(ξ)dξ(3)
可以看到,SDE/ODE 的形式是由 f , g f,g f,g 定义的。但是实际上,边缘分布 p t ( x ) p_t(\mathbf{x}) pt(x) 才是我们在扩散模型形式设计中最关心的东西,它关系到模型训练、采样加速以及如何理解 ODE 的实际行为。而 f , g f,g f,g 与边缘分布的关系并不直接。因此,EDM 中作者认为直接从 s , σ s,\sigma s,σ 的角度出发来定义扩散模型的规格是更合理、更有实际意义的。接下来,我们将推到从 s , σ s,\sigma s,σ 的角度来从新推导 ODE 形式。
ODE 形式重新推导
将边缘分布写成积分形式并进行推导(原文式 15-19),有如下结果:
p
t
(
x
)
=
∫
R
d
p
0
t
(
x
∣
x
0
)
p
data
(
x
0
)
d
x
0
=
.
.
.
=
s
(
t
)
−
d
[
p
data
∗
N
(
0
,
σ
2
(
t
)
I
)
]
(
x
/
s
(
t
)
)
\begin{aligned} p_t(\mathbf{x})&=\int_{\mathbb{R}^d}p_{0t}(\mathbf{x}|\mathbf{x}_0)p_\text{data}(\mathbf{x}_0)d\mathbf{x}_0 \\ &=... \\ &=s(t)^{-d}[p_\text{data}*\mathcal{N}(0,\sigma^2(t)\mathbf{I})](\mathbf{x}/s(t)) \end{aligned} \notag
pt(x)=∫Rdp0t(x∣x0)pdata(x0)dx0=...=s(t)−d[pdata∗N(0,σ2(t)I)](x/s(t))
其中
∗
*
∗ 表示两个概率密度函数的卷积操作,
d
d
d 是数据集的维度。方括号内的卷积操作相当于是对数据分布进行加噪扰动,我们将这样得到的加噪分布记为
p
(
x
;
σ
)
=
p
data
∗
N
(
0
,
σ
2
(
t
)
I
)
p(\mathbf{x};\sigma)=p_\text{data}*\mathcal{N}(0,\sigma^2(t)\mathbf{I})
p(x;σ)=pdata∗N(0,σ2(t)I)。从而可以用加噪分布
p
(
x
;
σ
)
p(\mathbf{x};\sigma)
p(x;σ) 来表示边缘分布
p
t
(
x
)
p_t(\mathbf{x})
pt(x):
p
t
(
x
)
=
s
(
t
)
−
d
p
(
x
/
s
(
t
)
;
σ
(
t
)
)
p_t(\mathbf{x})=s(t)^{-d}p(\mathbf{x}/s(t);\sigma(t)) \notag
pt(x)=s(t)−dp(x/s(t);σ(t))
将这个关系带入到概率流 ODE (式 2)中,并整理得到:
d
x
=
.
.
.
=
[
f
(
t
)
x
−
1
2
g
2
(
t
)
∇
x
log
(
p
(
x
/
s
(
t
)
,
σ
(
t
)
)
]
d
t
(4)
d\mathbf{x}=...=[f(t)\mathbf{x}-\frac{1}{2}g^2(t)\nabla_\mathbf{x}\log(p(\mathbf{x}/s(t),\sigma(t))]dt \tag{4}
dx=...=[f(t)x−21g2(t)∇xlog(p(x/s(t),σ(t))]dt(4)
基于式 3,我们可以反写出用
s
,
σ
s,\sigma
s,σ 表示
f
,
g
f,g
f,g 的形式:
f
(
t
)
=
s
˙
(
t
)
/
s
(
t
)
g
(
t
)
=
s
(
t
)
2
σ
˙
(
t
)
σ
(
t
)
f(t)=\dot{s}(t)/s(t)\quad g(t)=s(t)\sqrt{2\dot{\sigma}(t)\sigma(t)} \notag
f(t)=s˙(t)/s(t)g(t)=s(t)2σ˙(t)σ(t)
其中上面的点表示对时间
t
t
t 的微分。现在,可以将式 (4) 中的
f
,
g
f,g
f,g 也替换成
s
,
σ
s,\sigma
s,σ:
d
x
=
.
.
.
=
[
s
˙
(
t
)
s
(
t
)
x
−
s
2
(
t
)
σ
˙
(
t
)
σ
(
t
)
∇
x
log
p
(
x
s
(
t
)
;
σ
(
t
)
)
]
d
t
(5)
d\mathbf{x}=...=\left[\frac{\dot{s}(t)}{s(t)}\mathbf{x}-s^2(t)\dot{\sigma}(t)\sigma(t)\nabla_{\mathbf{x}}\log p(\frac{\mathbf{x}}{s(t)};\sigma(t))\right]dt \tag{5}
dx=...=[s(t)s˙(t)x−s2(t)σ˙(t)σ(t)∇xlogp(s(t)x;σ(t))]dt(5)
这就得到了 EDM 中基于
s
,
σ
s,\sigma
s,σ 的 ODE 方程。我们再来理解一下
s
,
σ
s,\sigma
s,σ 的含义,
s
(
t
)
s(t)
s(t) 实际上是对于原始
x
^
\hat{\mathbf{x}}
x^ 的一个关于时间
t
t
t 的缩放因子
x
=
s
(
t
)
x
^
\mathbf{x}=s(t)\hat{\mathbf{x}}
x=s(t)x^,
σ
(
t
)
\sigma(t)
σ(t) 则是在
t
t
t 时刻加噪分布的标准差,也就是我们常说的噪声计划表 noise schedule。注意这里显式地撤销了
s
(
t
)
s(t)
s(t) 的缩放,从而保持
p
(
x
;
σ
)
p(\mathbf{x};\sigma)
p(x;σ) 是独立于
s
(
t
)
s(t)
s(t) 的。 可以看到,在这个式子中,我们唯一未知的就是得分
∇
x
log
p
(
x
;
σ
)
\nabla_\mathbf{x}\log p(\mathbf{x};\sigma)
∇xlogp(x;σ),这正是我们的神经网络需要去学习估计的东西。
训练目标:去噪得分匹配
在训练时,扩散模型一般采用去噪得分匹配(denoising score matching)的目标函数。这种方法绕开了对无法计算的归一化常数
Z
Z
Z,直接计算加噪分布的得分。具体来说,如果我们记
D
(
x
;
σ
)
D(\mathbf{x};\sigma)
D(x;σ) 是一个去噪函数,它能够最小化下面的 L2 误差:
E
y
∼
p
data
E
n
∼
N
(
0
,
σ
2
I
)
∣
∣
D
(
y
+
n
;
σ
)
−
y
∣
∣
2
2
\mathbb{E}_{\mathbf{y}\sim p_\text{data}}\mathbb{E}_{\mathbf{n}\sim\mathcal{N}(0,\sigma^2\mathbf{I})}||D(\mathbf{y}+\mathbf{n};\sigma)-\mathbf{y}||_2^2 \notag
Ey∼pdataEn∼N(0,σ2I)∣∣D(y+n;σ)−y∣∣22
经推导(附录 B.3 ),有:
∇
x
log
p
(
x
;
σ
)
=
(
D
(
x
;
σ
)
−
x
)
/
σ
2
\nabla_\mathbf{x}\log p(\mathbf{x};\sigma)=(D(\mathbf{x};\sigma)-\mathbf{x})/\sigma^2 \notag
∇xlogp(x;σ)=(D(x;σ)−x)/σ2
其中
y
\mathbf{y}
y 是干净数据,
n
\mathbf{n}
n 是噪声。从上式的角度来看,得分函数的实际上在做的事情就是将噪声从干净数据中分离出来,即所谓去噪。扩散模型中,我们将上式作为目标函数,训练一个神经网络
D
θ
(
x
;
σ
)
D_\theta(\mathbf{x};\sigma)
Dθ(x;σ) 来估计这个去噪函数
D
D
D。
ODE 离散化与求解
现在,我们去噪得分匹配的目标函数训练了一个神经网络 D θ D_\theta Dθ,估计出了加噪分布的得分,将这个得分代回到式 5,我们就得到了一个可求解的 ODE。接下来的采样过程,就是要在给定初值 x T \mathbf{x}_T xT 的条件下,求解出 x 0 \mathbf{x}_0 x0,这样就得到了符合数据分布 p data p_\text{data} pdata 的干净数据。
注意,我们现在的采样过程完全就是在解一个已知 ODE 的初值问题,只不过每一步 ODE 中的某些项(得分)需要我们跑一下神经网络估计出来(这就是一次所谓的 NFE Network Function Evaluation)。既然是求解 ODE,那么采样过程就我们需要确定两件事情:一是采用何种求解器(ODE Solver),二是采用什么样的求解步长离散策略(Schedule)。
preconditioning和损失权重
在训练神经网络时,有一个非常有用的经验是:我们最好保持网络的输入输出在一个稳定的数值范围内(比如具有单位方差)。这样可以避免梯度的数值有大范围的波动,从而稳定训练。而在训练扩散模型来拟合去噪函数 D D D 时,由于我们需要向干净数据中加入方差不等的高斯噪声,因此是无法满足这一要求的。我们一般考虑改为训练神经网络 F θ F_\theta Fθ(而非直接训练 D θ D_\theta Dθ),再通过适当的形式化调整计算出去噪函数 D D D。之前的工作一般是引入一个关于 σ \sigma σ 的归一化项,将噪声 n \mathbf{n} n 归一化到单位方差,让模型来预测这个单位方差的噪声,再计算出去噪函数 D θ = x − σ F θ ( ⋅ ) D_\theta=\mathbf{x}-\sigma F_\theta(\cdot) Dθ=x−σFθ(⋅)。作者认为,这样做的问题是当 σ \sigma σ 非常大时,网络预测的一点微小错误都会被这个很大的 σ \sigma σ 进一步放大,网络的学习比较困难,此时反而是直接预测 D D D 对模型来说比较简单。
为了改善这个问题,作者参考之前自适应混合信号噪声的方法,提出了一种新的扩散模型形式化。通过引入一个关于
σ
\sigma
σ 的跳跃连接来对模型进行 preconditioning,这样模型可以预测
y
\mathbf{y}
y,也可以预测
n
\mathbf{n}
n,也可以预测二者之间。具体来说,作者将
D
θ
D_\theta
Dθ 构造成如下形式:
D
θ
(
x
;
σ
)
=
c
skip
(
σ
)
x
+
c
out
(
σ
)
F
θ
(
c
in
(
σ
)
x
;
c
noise
(
σ
)
)
(6)
D_\theta(\mathbf{x};\sigma)=c_\text{skip}(\sigma)\mathbf{x}+c_\text{out}(\sigma)F_\theta\left(c_\text{in}(\sigma)\mathbf{x};c_\text{noise}(\sigma)\right) \tag{6}
Dθ(x;σ)=cskip(σ)x+cout(σ)Fθ(cin(σ)x;cnoise(σ))(6)
其中
F
θ
F_\theta
Fθ 是实际训练的神经网络,
c
skip
(
σ
)
c_\text{skip}(\sigma)
cskip(σ) 用于调制跳跃连接,
c
in
(
σ
)
,
c
out
(
σ
)
c_\text{in}(\sigma),c_\text{out}(\sigma)
cin(σ),cout(σ) 分别用于对输入和输出进行缩放,
c
noise
(
σ
)
c_\text{noise}(\sigma)
cnoise(σ) 将噪声等级
σ
\sigma
σ 转换为网络的输入条件。
我们之前提到 DDPM 通过修改原扩散模型的形式化,从预测原图改为预测噪声,大大提升了出图的质量。这里 EDM 实际上是将扩散模型的形式化拓展成了一种更一般的形式。其中 c skip , c out , c in , c noise c_\text{skip},c_\text{out},c_\text{in},c_\text{noise} cskip,cout,cin,cnoise 的选取,扩散模型训练过程设计选项的关键。
此外,在训练过程中,损失的加权方式 λ ( σ ) \lambda(\sigma) λ(σ) 和采样 σ \sigma σ 的分布,也很重要。
总结
以上我们已经将整个扩散模型的设计空间(包括训练过程、采样过程)的所有规格可选项全部梳理出来了。具体包括如下内容:
采样
- s ( t ) s(t) s(t):缩放因子
- σ ( t ) \sigma(t) σ(t):噪声的标准差
- { t i } \{t_i\} {ti}:离散化的步长
- ODE Solver:所选用的 ODE 求解器
训练
-
preconditioning: c skip ( σ ) c_\text{skip}(\sigma) cskip(σ):用于调制跳跃连接; c out ( σ ) c_\text{out}(\sigma) cout(σ):输出缩放因子; c in ( σ ) c_\text{in}(\sigma) cin(σ):输入缩放因子; c noise ( σ ) c_\text{noise}(\sigma) cnoise(σ):噪声条件化
-
损失加权: λ ( σ ) \lambda(\sigma) λ(σ)
-
采样 σ \sigma σ 的分布
在 EDM 论文中,作者选取了 DDPM(VP SDE)、SMLD(VE SDE)、improved DDPM training + DDIM sampling 三种方法作对比分析通过理论分析和实验,设计出了自己的 EDM,所有这些方法在 EDM 框架下的设计选项重新推到结果如下表所示。之所以选择这几项工作作为对比,一是因为它们都是当时 SOTA 的扩散生成模型,另一点原因,是它们分别来自不同的扩散模型理论,通过将他们全都在上述提出的扩散模型设计空间中进行统一的分析、对比,也验证了本文所提框架对不同扩散模型理论的统一性和有效性。
本文所提的扩散模型分析框架,将之前工作中耦合在一起的扩散模型规格设计解耦独立开来,各组件之间不再有隐式依赖关系,可以进行独立的优化选择,而不会影响彼此。
确定性采样
我们先来看如何基于新的框架,如何优化扩散模型的采样过程。这里我们先看 ODE 确定性采样。采样过程有两个优化目标。一是提升出图的质量,由 FID 指标来评估,越低越好。二是提高采样速度,采样过程时间消耗的大头是在模型推理上,因此可以用模型推理的次数 NFE 来反映采样速度。这两者之间存在权衡关系,我们需要在尽可能快的采样速度下得到尽量好的结果。
作者认为,采样过程中的各项参数选择,即缩放因子 s ( t ) s(t) s(t),噪声方差 σ ( t ) \sigma(t) σ(t) 和 ODE 离散化的策略 { t i } \{t_i\} {ti},与其他组件(模型结构和训练细节)应当是相互独立的。模型 D θ D_\theta Dθ 对于采样过程来说,可以看作是一个黑盒的得分估计器,其结构和训练方式不会影响采样参数的选择,采样参数也不应该决定模型的训练过程。作者对来自不同理论的扩散模型方法进行了统一的实验,测试各种采样选项对不同模型的影响,并提出通用的最优选项。
离散化和求解器
首先我们来看 ODE 的离散化方式和求解器的选择。
ODE 的数值求解实际上是对真实解轨迹的离散化近似。每一步都会存在误差,并且会逐步积累。常用的欧拉法一阶 ODE 求解器,其单步误差与步长的关系是 O ( h 2 ) \mathcal{O}(h^2) O(h2)。更高阶的龙格库塔等方法精度更高,但高阶求解器会需要更多的推理次数 NFE。作者通过广泛的实验来权衡 NFE 和误差,最终推荐选用二阶 Heun 方法。
二阶 Heun 求解器的具体算法步骤如图所示。其实是在欧拉方法的基础上加了一个校正步(6,7,8 三行,去掉这三行就是欧拉法),这样需要一次额外的模型推理,但是能将单步误差降低到 O ( h 3 ) \mathcal{O}(h^3) O(h3),从而在采样速度和误差之间达到最好的权衡。
离散化方式,也就是采样时选取的时间步
{
t
i
}
\{t_i\}
{ti},决定了每步的步长
h
h
h,从而决定了不同噪声程度下的误差分布。作者的结论是,步长应该随着
σ
\sigma
σ 的降低单调递减,并且没有必要根据不同样本采用不同的步长。作者根据
σ
\sigma
σ 构造了形如一个
σ
i
<
N
=
(
A
i
+
B
)
ρ
\sigma_{i<N}=(Ai+B)^\rho
σi<N=(Ai+B)ρ 参数化的步长策略,考虑到我们需要保证
σ
0
=
σ
max
,
σ
n
−
1
=
σ
min
\sigma_0=\sigma_\text{max},\sigma_{n-1}=\sigma_\text{min}
σ0=σmax,σn−1=σmin,具体公式写为:
σ
i
<
N
=
(
σ
max
1
ρ
+
i
N
−
1
(
ρ
min
1
ρ
−
σ
max
1
ρ
)
)
ρ
,
σ
N
=
0
(8)
\sigma_{i<N}=\left(\sigma_\text{max}^{\frac{1}{\rho}}+\frac{i}{N-1}(\rho_\text{min}^{\frac{1}{\rho}}-\sigma_\text{max}^{\frac{1}{\rho}})\right)^\rho ,\quad\sigma_N=0 \tag{8}
σi<N=(σmaxρ1+N−1i(ρminρ1−σmaxρ1))ρ,σN=0(8)
参数
ρ
\rho
ρ 越大,在
σ
\sigma
σ 较小处的步长越小,但同时
σ
\sigma
σ 较大处的步长也会越大。这样相当于让采样过程在
σ
\sigma
σ 较小处的求解更精细。当设置
ρ
=
3
\rho=3
ρ=3 时,在每一步的单步误差基本是相等的,但是当
ρ
∈
[
5
,
10
]
\rho\in[5,10]
ρ∈[5,10] 时出图质量更高。这表明在
σ
\sigma
σ 较小时的误差影响相对比较大。作者最终推荐的值是
ρ
=
7
\rho=7
ρ=7。
轨迹曲率和噪声计划表
然后我们再来看
σ
(
t
)
\sigma(t)
σ(t) 和
s
(
t
)
s(t)
s(t),这两项的选择决定了 ODE 解轨迹的形状。上面提到的 ODE 离散数值求解误差与
d
x
/
d
t
d\mathbf{x}/dt
dx/dt 的曲率是成比例的,因此选择一个更好的
s
,
σ
s,\sigma
s,σ 能够降低误差。作者直接给出最优的选择是取:
σ
(
t
)
=
t
,
s
(
t
)
=
1
\sigma(t)=t,\quad s(t)=1 \notag
σ(t)=t,s(t)=1
这与 DDIM 采样的选择是一样的。这样一来,ODE 式就可以简化为:
d
x
/
d
t
=
(
x
−
D
(
x
,
t
)
)
/
t
d\mathbf{x}/dt=(\mathbf{x}-D(\mathbf{x},t))/t \notag
dx/dt=(x−D(x,t))/t
注意现在
σ
\sigma
σ 和
t
t
t 是等价的了。这样选择一个直接的好处是在任何一个时刻
t
t
t,解轨迹的切线都指向起点
x
0
\mathbf{x}_0
x0。当模型训练地足够好时,在任何时刻
t
t
t 都能一步回到干净图片
x
0
\mathbf{x}_0
x0。顺便提一句,沿着这个分析走下去,我们马上就能得到大名鼎鼎的一致性模型(Consistency Models)了。
下图中对比了三种不同的 ODE 在一维数据上的解轨迹可视化结果。图 a 中的 VP-ODE 的解轨迹,在 $\sigma $ 较大时一直是水平的,直到 σ \sigma σ 很小之后才开始指向数据分布;图 b 中的 VE-ODE 的解轨迹,一直都是弯曲的;图 c 中是本文 edm 的(也是 ddim 的)解轨迹,随着 σ \sigma σ 增大,解轨迹接近指向数据均值的直线。
下图对比了基于不同理论方法的扩散模型(VP、VE 和 DDIM)在采用了 EDM 中提出的采样选项后,采样速度和出图质量都有所提升。
随机性采样
确定性采样有很多好处,比如可以对图片作 inversion 得到 latent,然后进行图像编辑。但是,实际中确定性采样的出图质量通常不如随机性采样。
Song 等人的 SDE 可以推广为如下形式:
d
x
±
=
−
σ
˙
(
t
)
σ
(
t
)
∇
x
log
p
(
x
;
σ
(
t
)
)
d
t
⏟
PF ODE
±
β
(
t
)
σ
2
(
t
)
∇
x
log
(
x
;
σ
(
t
)
)
d
t
⏟
确定性的噪声衰减
+
2
β
(
t
)
σ
(
t
)
d
ω
t
⏟
随机噪声注入
⏟
朗之万扩散 SDE
(9)
d\mathbf{x}_{\pm} =\underbrace{-\dot{\sigma}(t)\sigma(t)\nabla_\mathbf{x}\log p(\mathbf{x};\sigma(t))dt}_{\text{PF ODE}} \pm\underbrace{\underbrace{\beta(t)\sigma^2(t)\nabla_\mathbf{x}\log(\mathbf{x};\sigma(t))dt}_{确定性的噪声衰减} +\underbrace{\sqrt{2\beta(t)}\sigma(t)d\omega_t}_{随机噪声注入}}_{\text{朗之万扩散 SDE}} \tag{9}
dx±=PF ODE
−σ˙(t)σ(t)∇xlogp(x;σ(t))dt±朗之万扩散 SDE
确定性的噪声衰减
β(t)σ2(t)∇xlog(x;σ(t))dt+随机噪声注入
2β(t)σ(t)dωt(9)
其中 ω ( t ) \omega(t) ω(t) 是标准维纳过程, d x + , d x − d\mathbf{x}_+,d\mathbf{x}_- dx+,dx− 分别表示 SDE 前向和反向移动。该 SDE 形式可以看做在 PF-ODE 的基础上添加了朗之万扩散 SDE 项,而该项又可被拆解分一个确定性的去噪项和一个随机的噪声注入项,它们对净噪声水平的贡献相互抵消。 β ( t ) \beta(t) β(t) 用于控制现有的噪声被替代为新噪声的比率,当取 β ( t ) = σ ˙ ( t ) σ ( t ) \beta(t)=\dot{\sigma}(t)\sigma(t) β(t)=σ˙(t)σ(t) 时,前向过程中的 score 项就没了,这就回到了 Song 等人提出的 SDE 形式。
从这个角度,我们可以看到为什么引入随机性对出图质量是有帮助的:隐式的朗之万扩散将样本推向特定时刻的目标边缘分布,积极纠正采样早期出现的误差。另一方面,用离散的 SDE 求解器步骤近似朗之万项本身会引入误差。已有工作证明确定性的噪声衰减这一项( β ( t ) ≠ 0 \beta(t)\ne0 β(t)=0)是有帮助的, β ( t ) \beta(t) β(t) 实际如何确定,也应该通过实验来分析。
本文提出了的随机采样器(如下算法所示),在二阶确定性 ODE 积分器的基础上结合了一个显式的再添加和移除噪声,作者将这个过程形象地称为“搅拌(churn)” 。具体来说,在第 i i i 步,样本 x i \mathbf{x}_i xi 的噪声水平 t i t_i ti(由于 σ = t \sigma=t σ=t,所以其实也就是 σ ( t i ) \sigma(t_i) σ(ti) 了),我们执行两个子步骤。首先,根据一个因子 γ i ≥ 0 \gamma_i\ge 0 γi≥0 向样本添加噪声,先达到更高的噪声水平 t ^ i = t i + γ i t i \hat{t}_i=t_i+\gamma_it_i t^i=ti+γiti。第二步,基于加过噪声的样本 x i + 1 \mathbf{x}_{i+1} xi+1,我们进行一步 ODE 求解,从 t ^ i \hat{t}_i t^i 回到 t i + 1 t_{i+1} ti+1。此时我们就得到了噪声强度为 t i + 1 t_{i+1} ti+1 下的样本 x i + 1 \mathbf{x}_{i+1} xi+1。然后再进行下一步的迭代。
引入随机性有助于纠正早期采样步骤中产生的错误,但它也有自身的缺陷。作者观察到过度的噪声添加和移除过程会导致生图结果逐渐丢失细节。在噪声水平特别高和特别低时,还会出现颜色过饱和的漂移。作者猜测这是因为实际的去噪网络 D θ D_\theta Dθ 在估计噪声时诱导了一个轻微的非守恒向量场,破坏了朗之万扩散的前提条件,从而导致了这些问题。
如果这种退化是由网络本身 D θ ( x ; σ ) D_\theta(\mathbf{x};\sigma) Dθ(x;σ) 的缺陷引起的,那么只能在采样过程中根据经验来补救。具体来说,作者限制仅在特定的噪声水平范围 [ S tmin , S tmax ] [S_\text{tmin}, S_\text{tmax}] [Stmin,Stmax] 内引入随机性,来解决向过饱和颜色的漂移;定义 γ i = S churn / N \gamma_i=S_\text{churn}/N γi=Schurn/N,其中 S churn S_\text{churn} Schurn 用来控制引入控制随机性的总量。并且限制 γ i \gamma_i γi 的值,以确保引入的新噪声不会超过图像中已有的噪声。此外,作者还发现通过设置 S noise S_\text{noise} Snoise 略大于 1,可以部分抵消除去细节的损失,以增加新添加噪声的标准差。这表明猜测的 D θ ( x ; σ ) D_\theta(\mathbf{x};\sigma) Dθ(x;σ) 导致的非守恒性的主要组成部分是倾向于去除稍多的噪声,这很可能是由于回归均值的倾向,曾有研究指出这是任何 L2 训练的去噪器都可能发生的情况。
模型训练
上面,我们通过引入 preconditioning c skip , c out , c in , c noise c_\text{skip},c_\text{out},c_\text{in},c_\text{noise} cskip,cout,cin,cnoise,将扩散模型的训练过程重新形式化为式 (6)。现在,我们要根据神经网络训练过程中的一些经验,确定下这几个设计选项。
首先,我们基于式 6 将损失重写为关于
F
θ
F_\theta
Fθ 的形式:
E
σ
,
y
,
n
[
λ
(
σ
)
c
out
2
(
σ
)
⏟
权重
∣
∣
F
θ
(
c
in
(
σ
)
⋅
(
y
+
n
)
;
c
noise
(
σ
)
)
⏟
网络输出
−
1
c
out
(
y
−
c
skip
(
σ
)
⋅
(
y
+
n
)
∣
∣
2
2
⏟
训练目标
]
(10)
\mathbb{E}_{\sigma,\mathbf{y},\mathbf{n}}\left[\underbrace{\lambda(\sigma)c_\text{out}^2(\sigma)}_{权重}||\underbrace{F_\theta(c_\text{in}(\sigma)\cdot(\mathbf{y}+\mathbf{n});c_\text{noise}(\sigma))}_{网络输出}-\underbrace{\frac{1}{c_\text{out}}(\mathbf{y}-c_\text{skip}(\sigma)\cdot(\mathbf{y}+\mathbf{n})||_2^2}_{训练目标}\right] \tag{10}
Eσ,y,n
权重
λ(σ)cout2(σ)∣∣网络输出
Fθ(cin(σ)⋅(y+n);cnoise(σ))−训练目标
cout1(y−cskip(σ)⋅(y+n)∣∣22
(10)
写成这样的形式之后,我们就能根据输入输出具有单位方差、尽量少地放大网络
F
θ
F_\theta
Fθ 的偏差、权重在不同噪声层级之间尽可能均匀分布这几条网络训练的经验,来分别推导出
c
in
,
c
out
c_\text{in},c_\text{out}
cin,cout 、
c
skip
c_\text{skip}
cskip 和
λ
(
σ
)
\lambda(\sigma)
λ(σ)。至于
c
noise
c_\text{noise}
cnoise,作者是通过实验确定下来的。具体来说,我们希望
1 网络输入具有单位方差,即:
Var
y
,
n
[
c
in
(
σ
)
(
y
+
n
)
]
=
1
c
in
2
(
σ
)
Var
(
y
+
n
)
=
1
c
in
2
(
σ
)
(
σ
data
2
+
σ
2
)
=
1
c
in
(
σ
)
=
1
σ
data
2
+
σ
2
\begin{aligned} \text{Var}_{\mathbf{y},\mathbf{n}}[c_\text{in}(\sigma)(\mathbf{y}+\mathbf{n})]&=1 \\ c_\text{in}^2(\sigma)\text{Var}(\mathbf{y}+\mathbf{n})&=1 \\ c_\text{in}^2(\sigma)(\sigma_\text{data}^2+\sigma^2)&=1 \\ c_\text{in}(\sigma) &=\frac{1}{\sigma_\text{data}^2+\sigma^2} \end{aligned} \notag
Vary,n[cin(σ)(y+n)]cin2(σ)Var(y+n)cin2(σ)(σdata2+σ2)cin(σ)=1=1=1=σdata2+σ21
2 训练目标具有单位方差,即:
Var
y
,
n
[
1
c
out
(
y
−
c
skip
(
σ
)
⋅
(
y
+
n
)
)
]
=
1
1
c
out
2
(
σ
)
Var
y
,
n
[
y
−
c
skip
(
σ
)
⋅
(
y
+
n
)
]
=
1
c
out
2
(
σ
)
=
Var
y
,
n
[
(
1
−
c
skip
(
σ
)
)
y
+
c
skip
(
σ
)
n
]
c
out
2
(
σ
)
=
(
1
−
c
skip
(
σ
)
)
2
σ
data
2
+
c
skip
2
(
σ
)
σ
2
\begin{aligned} \text{Var}_{\mathbf{y},\mathbf{n}}\left[\frac{1}{c_\text{out}}(\mathbf{y}-c_\text{skip}(\sigma)\cdot(\mathbf{y}+\mathbf{n}))\right]&=1 \\ \frac{1}{c^2_\text{out}(\sigma)}\text{Var}_{\mathbf{y},\mathbf{n}}[\mathbf{y}-c_\text{skip}(\sigma)\cdot(\mathbf{y}+\mathbf{n})]&=1 \\ c_\text{out}^2(\sigma)&=\text{Var}_{\mathbf{y},\mathbf{n}}[(1-c_\text{skip}(\sigma))\mathbf{y}+c_\text{skip}(\sigma)\mathbf{n}] \\ c_\text{out}^2(\sigma)&=(1-c_\text{skip}(\sigma))^2\sigma_\text{data}^2+c_\text{skip}^2(\sigma)\sigma^2 \end{aligned} \notag
Vary,n[cout1(y−cskip(σ)⋅(y+n))]cout2(σ)1Vary,n[y−cskip(σ)⋅(y+n)]cout2(σ)cout2(σ)=1=1=Vary,n[(1−cskip(σ))y+cskip(σ)n]=(1−cskip(σ))2σdata2+cskip2(σ)σ2
3 我们想要尽可能减小网络的误差被放大的程度,为此,我们需要使得
c
out
(
σ
)
c_\text{out}(\sigma)
cout(σ) 最小的
c
skip
(
σ
)
c_\text{skip}(\sigma)
cskip(σ),即:
c
skip
=
arg
min
c
skip
(
σ
)
c
out
(
σ
)
c_\text{skip}=\arg\min_{c_\text{skip}(\sigma)}c_\text{out}(\sigma)
cskip=argmincskip(σ)cout(σ),由于
c
out
≥
0
c_\text{out}\ge 0
cout≥0,因此可以等价地写为
c
skip
=
arg
min
c
skip
(
σ
)
c
out
2
(
σ
)
c_\text{skip}=\arg\min_{c_\text{skip}(\sigma)}c^2_\text{out}(\sigma)
cskip=argmincskip(σ)cout2(σ)。这是一个凸优化问题,当关于
c
skip
c_\text{skip}
cskip 的导数为 0 时取得唯一解,再将上面的关系带入,有:
d
[
c
out
2
(
σ
)
]
/
d
c
skip
(
σ
)
=
0
d
[
(
1
−
c
skip
(
σ
)
)
2
σ
data
2
+
c
skip
2
(
σ
)
σ
2
]
/
d
c
skip
(
σ
)
=
0
σ
data
2
d
[
(
1
−
c
skip
(
σ
)
)
2
]
/
d
c
skip
(
σ
)
+
σ
2
d
[
c
skip
2
(
σ
)
]
/
d
c
skip
(
σ
)
=
0
σ
data
2
[
2
c
skip
(
σ
)
−
2
]
+
σ
2
[
2
c
skip
(
σ
)
]
=
0
(
σ
2
+
σ
data
2
)
c
skip
)
(
σ
)
−
σ
data
2
=
0
c
skip
(
σ
)
=
σ
data
2
σ
2
+
σ
data
2
\begin{aligned} d[c_\text{out}^2(\sigma)]/dc_\text{skip}(\sigma)&=0 \\ d[(1-c_\text{skip}(\sigma))^2\sigma_\text{data}^2+c_\text{skip}^2(\sigma)\sigma^2]/dc_\text{skip}(\sigma) &=0\\ {\sigma^2_\text{data}d[(1-c_\text{skip}(\sigma))^2]}/{dc_\text{skip}(\sigma)}+\sigma^2d[c^2_\text{skip}(\sigma)]/dc_\text{skip}(\sigma)&=0 \\ \sigma_\text{data}^2 [2c_\text{skip}(\sigma)-2]+\sigma^2[2c_\text{skip}(\sigma)]&=0 \\ (\sigma^2+\sigma_\text{data}^2)c_\text{skip})(\sigma)-\sigma_\text{data}^2 &=0\\ c_\text{skip}(\sigma)&=\frac{\sigma_\text{data}^2}{\sigma^2+\sigma_\text{data}^2} \end{aligned} \notag
d[cout2(σ)]/dcskip(σ)d[(1−cskip(σ))2σdata2+cskip2(σ)σ2]/dcskip(σ)σdata2d[(1−cskip(σ))2]/dcskip(σ)+σ2d[cskip2(σ)]/dcskip(σ)σdata2[2cskip(σ)−2]+σ2[2cskip(σ)](σ2+σdata2)cskip)(σ)−σdata2cskip(σ)=0=0=0=0=0=σ2+σdata2σdata2
再将这个结果代回,可以求出
c
out
(
σ
)
c_\text{out}(\sigma)
cout(σ):
c
out
(
σ
)
=
σ
⋅
σ
data
/
σ
2
+
σ
data
2
c_\text{out}(\sigma)=\sigma\cdot\sigma_\text{data}/\sqrt{\sigma^2+\sigma_\text{data}^2} \notag
cout(σ)=σ⋅σdata/σ2+σdata2
4 我们希望损失权重在不同的噪声层级之间是均匀分布的,即:
w
(
σ
)
=
λ
(
σ
)
c
out
2
(
σ
)
=
1
.
.
.
λ
(
σ
)
=
(
σ
2
+
σ
data
2
)
/
(
σ
⋅
σ
data
)
2
w(\sigma)=\lambda(\sigma)c_\text{out}^2(\sigma)=1 \\ ... \\ \lambda(\sigma)=(\sigma^2+\sigma_\text{data}^2)/(\sigma\cdot\sigma_\text{data})^2 \notag
w(σ)=λ(σ)cout2(σ)=1...λ(σ)=(σ2+σdata2)/(σ⋅σdata)2
5 我们还想通过调整训练时采样
σ
\sigma
σ 的策略
p
train
(
σ
)
p_\text{train}(\sigma)
ptrain(σ),尽可能高效地进行训练。作者进行了相关的实验,发现只有在中等强度的噪声处,损失才可能降得很小。作者认为,在噪声强度较低时,识别出图像中的微弱噪声很困难,而且对最终的出图结果也没有特别大的帮助;而在噪声强度较高时,训练目标又与数据集平均的正确结果非常不同。因此,在中等噪声强度处的去噪能力是最有意义的,需要尽可能多的将
σ
\sigma
σ 采样在这个区间内。为此,作者采用了一种简单地 log-normal 分布来用作
p
train
(
σ
)
p_\text{train}(\sigma)
ptrain(σ)。
至此,我们已经将训练阶段作者推荐的配置全部或推导,或分析,或实验出来了:
preconditioning:
{ c skip ( σ ) = σ data 2 / ( σ 2 + σ data 2 ) c out ( σ ) = σ ⋅ σ data / σ data 2 + σ 2 c in ( σ ) = 1 / σ data 2 + σ 2 c noise ( σ ) = 1 4 ln ( σ ) \begin{aligned} \begin{cases} c_\text{skip}(\sigma)&=\sigma_\text{data}^2/(\sigma^2+\sigma_\text{data}^2) \\ c_\text{out}(\sigma)&=\sigma\cdot\sigma_\text{data}/\sqrt{\sigma_\text{data}^2+\sigma^2} \\ c_\text{in}(\sigma)&=1/\sqrt{\sigma_\text{data}^2+\sigma^2} \\ c_\text{noise}(\sigma)&=\frac{1}{4}\ln(\sigma) \end{cases} \end{aligned} \notag ⎩ ⎨ ⎧cskip(σ)cout(σ)cin(σ)cnoise(σ)=σdata2/(σ2+σdata2)=σ⋅σdata/σdata2+σ2=1/σdata2+σ2=41ln(σ)
加权:
λ
(
σ
)
=
(
σ
2
+
σ
data
2
)
/
(
σ
⋅
σ
data
)
2
\lambda(\sigma)=(\sigma^2+\sigma_\text{data}^2)/(\sigma\cdot\sigma_\text{data})^2 \notag
λ(σ)=(σ2+σdata2)/(σ⋅σdata)2
采样:
ln
(
σ
)
∼
N
(
P
mean
,
P
std
2
)
,
P
mean
=
−
1.2
,
P
std
=
1.2
\ln(\sigma)\sim\mathcal{N}(P_\text{mean},P_\text{std}^2),\quad P_\text{mean}=-1.2,\ P_\text{std}=1.2 \notag
ln(σ)∼N(Pmean,Pstd2),Pmean=−1.2, Pstd=1.2
另外,在训练时,为了避免在小数据集上出现过拟合的现象,作者还引入了数据扩增管线,并将数据扩增方法作为额外的条件输入给模型以避免泄露。
总结
EDM 是扩散模型发展中非常重要、非常具有实际意义的一篇工作。它构建了一个清晰、统一的扩散模型框架,将之前的几种经典方法都囊括其中,在这个框架下,不同的扩散模型设计选项相互解耦,并且对训练、采样各过程的设计选项进行了清晰的拆分。随后,针对各个设计选项,或通过详尽的理论推导,或通过扎实的实验分析,得出了最优的扩散模型设计形式和典型值。从各大扩散模型代码库中带有 “Karras” 字样的采样和训练设计选项可以得知,这些结果直到今天都还非常有价值。
不过毕竟实验只是做在 32、64 这种小数据集上的,模型规模也不大,笔者比较好奇的是这些结论在高分辨率图像、大数据大模型上的应用价值如何呢?