Improved Denoising Diffusion Probabilistic Models
TL; DR:iDDPM 分析了 DDPM 形式化和训练过程的一些不足,并提出了可学习方差、余弦噪声计划、非均匀的时间步采样策略等多项改进。
前置知识
本文是针对 DDPM 的改进,首先来回顾一下 DDPM 的细节。
定义
给定数据分布
x
0
∼
q
(
x
0
)
x_0\sim q(x_0)
x0∼q(x0) ,我们通过一个联合分布
q
(
x
1
,
…
,
x
T
)
q(x_1,\dots,x_T)
q(x1,…,xT) 来定义前向加噪过程,在每个时间步
t
t
t 像原始数据中添加方差为
β
t
∈
(
0
,
1
)
\beta_t\in(0,1)
βt∈(0,1) 的高斯噪声,得到一系列隐变量
x
1
,
…
,
x
T
x_1,\dots,x_T
x1,…,xT:
q
(
x
1
,
…
,
x
T
)
:
=
∏
t
=
1
T
q
(
x
t
∣
x
t
−
1
)
q(x_1,\dots,x_T):=\prod_{t=1}^Tq(x_t|x_{t-1})\\
q(x1,…,xT):=t=1∏Tq(xt∣xt−1)
q ( x t ∣ x t − 1 ) : = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t|x_{t-1}):=\mathcal{N} (x_t;\sqrt{1-\beta_t}x_{t-1},\beta_t\mathbf{I}) q(xt∣xt−1):=N(xt;1−βtxt−1,βtI)
当
T
T
T 足够大,且
β
t
\beta_t
βt 设计合理,
x
T
x_T
xT 就几乎是一个纯高斯噪声。而如果我们有了反向的条件概率分布
q
(
x
t
−
1
∣
x
t
)
q(x_{t-1}|x_t)
q(xt−1∣xt) ,我们就能随机采样一个高斯噪声
x
T
∼
N
(
0
,
I
)
x_T\sim\mathcal{N}(0,\mathbf{I})
xT∼N(0,I) 然后反向一步步去噪,最终得到
x
0
∼
q
(
x
0
)
x_0\sim q(x_0)
x0∼q(x0) 。然而,
q
(
x
t
−
1
∣
x
t
)
q(x_{t-1}|x_t)
q(xt−1∣xt) 与整个数据分布有关,无法直接得到,因此我们训练一个神经网络来近似它:
p
θ
(
x
t
−
1
∣
x
t
)
:
=
N
(
x
t
−
1
;
μ
θ
(
x
t
,
t
)
,
Σ
θ
(
x
t
,
t
)
)
p_\theta(x_{t-1}|x_t):=\mathcal{N}(x_{t-1};\mu_\theta(x_t,t),\Sigma_\theta(x_t,t))
pθ(xt−1∣xt):=N(xt−1;μθ(xt,t),Σθ(xt,t))
其实这里的
q
q
q 和
p
p
p 合起来看就是一个变分自编码器 VAE,变分下界(VLB)可以写作:
L
vlb
:
=
L
0
+
L
1
+
⋯
+
L
T
−
1
+
L
T
L
0
:
=
−
log
p
θ
(
x
0
∣
x
1
)
L
t
−
1
:
=
D
K
L
(
q
(
x
t
−
1
∣
x
t
,
x
0
)
∣
∣
p
θ
(
x
t
−
1
∣
x
t
)
)
L
T
:
=
D
K
L
(
q
(
x
T
∣
x
0
)
∣
∣
p
(
x
T
)
)
L_\text{vlb}:=L_0+L_1+\dots+L_{T-1}+L_T\\ L_0:=-\log p_\theta(x_0|x_1)\\ L_{t-1}:=D_{KL}(q(x_{t-1}|x_t,x_0)||p_\theta(x_{t-1}|x_t))\\ L_T:=D_{KL}(q(x_T|x_0)||p(x_T))
Lvlb:=L0+L1+⋯+LT−1+LTL0:=−logpθ(x0∣x1)Lt−1:=DKL(q(xt−1∣xt,x0)∣∣pθ(xt−1∣xt))LT:=DKL(q(xT∣x0)∣∣p(xT))
除了
L
0
L_0
L0,上式中的每一项都是两个高斯分布之间的 KL 散度,这是完全可以解析计算出来的。对于
L
0
L_0
L0 ,我们假设图片的每个色彩元素有 256 种可能,并计算
p
θ
(
x
0
∣
x
1
)
p_\theta(x_0|x_1)
pθ(x0∣x1) 落在正确可能性的概率,这就可以通过高斯分布的 CDF 计算出来。而
L
T
L_T
LT 与参数
θ
\theta
θ 无关,如果前向过程足够充分地破坏了原始数据分布,即有
q
(
x
T
∣
x
0
)
≈
N
(
0
,
I
)
q(x_T|x_0)\approx\mathcal{N}(0,\mathbf{I})
q(xT∣x0)≈N(0,I),那么其损失值就接近 0 了。
DDPM 中还指出,式 (2) 所定义的前向加噪过程中是相邻单步的加噪,实际上从原始数据
x
0
x_0
x0 开始,可以直接完成任意多步加噪。记
α
t
:
=
1
−
β
t
,
α
ˉ
t
:
=
∏
s
=
0
t
α
s
\alpha_t:=1-\beta_t,\ \ \bar\alpha_t:=\prod_{s=0}^t\alpha_s
αt:=1−βt, αˉt:=∏s=0tαs ,边缘分布可重写为:
q
(
x
t
∣
x
0
)
=
N
(
x
t
;
α
ˉ
t
x
0
+
(
1
−
α
ˉ
t
)
I
)
q(x_t|x_0)=\mathcal{N}(x_t;\sqrt{\bar\alpha_t}x_0+(1-\bar\alpha_t)\mathbf{I})
q(xt∣x0)=N(xt;αˉtx0+(1−αˉt)I)
x t = α ˉ t x 0 + 1 − α ˉ t ϵ , ϵ ∼ N ( 0 , I ) x_t=\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon,\ \ \epsilon\sim\mathcal{N}(0,\mathbf{I}) xt=αˉtx0+1−αˉtϵ, ϵ∼N(0,I)
这里的 1 − α ˉ t 1-\bar\alpha_t 1−αˉt 是任意步加噪的方差,我们其实也可以等价地用 α ˉ t \bar\alpha_t αˉt 来代替 β t \beta_t βt 定义 noise schedule。
根据贝叶斯定理,我们可以定义
β
~
t
\tilde\beta_t
β~t 和
μ
~
t
(
x
t
,
x
0
)
\tilde\mu_t(x_t,x_0)
μ~t(xt,x0) 来计算后验
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t,x_0)
q(xt−1∣xt,x0) :
β
~
t
:
=
1
−
α
ˉ
t
−
1
1
−
α
ˉ
t
β
t
μ
~
t
(
x
t
,
x
0
)
:
=
α
ˉ
t
−
1
β
t
1
−
α
ˉ
t
x
0
+
α
t
(
1
−
α
ˉ
t
−
1
)
1
−
α
ˉ
t
x
t
q
(
x
t
−
1
∣
x
t
,
x
0
)
=
N
(
x
t
−
1
;
μ
~
(
x
t
,
x
0
)
,
β
~
t
I
)
\tilde\beta_t:=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t\\ \tilde\mu_t(x_t,x_0):=\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x_0+\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}x_t\\ q(x_{t-1}|x_t,x_0)=\mathcal{N}(x_{t-1};\tilde\mu(x_t,x_0),\tilde\beta_t\mathbf{I})
β~t:=1−αˉt1−αˉt−1βtμ~t(xt,x0):=1−αˉtαˉt−1βtx0+1−αˉtαt(1−αˉt−1)xtq(xt−1∣xt,x0)=N(xt−1;μ~(xt,x0),β~tI)
训练细节
训练的目标函数(式 4)是多个独立项 L t − 1 L_{t-1} Lt−1 的加和,而式 (5) 则提供了高效进行多步加噪的方式,从而我们就可以结合 (3) 式先验和 (7) 式后验来估计 L t − 1 L_{t-1} Lt−1 。DDPM 训练时,均匀地采样时间步 t t t,并使用期望 E t , x 0 , ϵ [ L t − 1 ] \mathbb{E}_{t,x_0,\epsilon}[L_{t-1}] Et,x0,ϵ[Lt−1] 来估计 L vlb L_{\text{vlb}} Lvlb 。
要训练一个扩散模型,我们需要选择具体的参数化形式,即:让模型学习预测什么。(当时已经)有许多种对
μ
θ
(
x
t
,
t
)
\mu_\theta(x_t,t)
μθ(xt,t) 进行参数化的方式,最显然的做法是使用神经网络直接预测
μ
θ
(
x
t
,
t
)
\mu_\theta(x_t,t)
μθ(xt,t);也可以使用神经网络预测
x
0
x_0
x0,再根据式 (7-2) 计算出
μ
θ
(
x
t
,
t
)
\mu_\theta(x_t,t)
μθ(xt,t) ;还可以训练神经网络预测噪声
ϵ
\epsilon
ϵ 并根据式 (6,7) 推出:
μ
θ
(
x
t
,
t
)
=
1
α
t
(
x
t
−
β
t
1
−
α
ˉ
t
ϵ
θ
(
x
t
,
t
)
)
\mu_\theta(x_t,t)=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(x_t,t))
μθ(xt,t)=αt1(xt−1−αˉtβtϵθ(xt,t))
DDPM 中发现,最后这一种预测噪声的参数化方式效果最后,特别是如果搭配上重加权损失:
L
simple
=
E
t
,
x
0
,
ϵ
[
∣
∣
ϵ
−
ϵ
θ
(
x
t
,
t
)
∣
∣
2
2
]
L_\text{simple}=\mathbb{E}_{t,x_0,\epsilon}[||\epsilon-\epsilon_\theta(x_t,t)||_2^2]
Lsimple=Et,x0,ϵ[∣∣ϵ−ϵθ(xt,t)∣∣22]
上式这种目标函数可以看作是
L
vlb
L_\text{vlb}
Lvlb 的一种重加权的形式,这种形式与
Σ
θ
\Sigma_\theta
Σθ 无关了。DDPM 的作者发现,相比于直接优化
L
vlb
L_\text{vlb}
Lvlb,使用这种目标函数训练出模型的采样结果质量特别好,并通过分析与去噪得分模型的关系来解释这一点。
L simple L_\text{simple} Lsimple 这种形式的微妙之处在于其没有学习方差 Σ θ ( x t , t ) \Sigma_\theta(x_t,t) Σθ(xt,t) 。DDPM 中使用固定的方差 σ t 2 I \sigma_t^2\mathbf{I} σt2I 得到了比学习方差更好的结果。并且还发现使用固定方差 σ t 2 = β t \sigma^2_t=\beta_t σt2=βt 或 σ t 2 = β ~ t \sigma^2_t=\tilde\beta_t σt2=β~t 效果差不多。这两个方差实际上分别对应了给定数据分布 q ( x 0 ) q(x_0) q(x0) 分别为各向同性的高斯分布和 delta 函数时所对应的方差的上界和下界。
优化Log Likelihood
DDPM 生成结果的感知质量(FID 和 IS 指标)非常好,但是其 log likelihood 却不如其他模型。一般认为,优化 Log likelihood 可以提升生成模型捕获数据分布中所有模式的能力,因此其是生成模型的重要指标。并且有工作指出,log likelihood 的一点提升就对采样质量和特征表示能力产生巨大影响。因此,本文要探究为什么 DDPM 在 log likelihood 这项指标上表现不佳以及如何改进,达到更高的 log likelihood。
可学习的方差
DDPM 中,作者设定了固定的方差 Σ θ ( x t , t ) = σ t 2 I \Sigma_\theta(x_t,t)=\sigma^2_t\mathbf{I} Σθ(xt,t)=σt2I ,其中 σ t \sigma_t σt 是不可学习的超参数。奇怪的是,他们发现固定方差为 σ t 2 = β t \sigma_t^2=\beta_t σt2=βt 或 σ t 2 = β ~ t \sigma_t^2=\tilde\beta_t σt2=β~t 结果都差不多,而这两种方差实际上是两种极端(方差的上下界),让人不禁想问问这是为什么。
作者给出了自己的分析。下图是 β ~ t / β t \tilde\beta_t/\beta_t β~t/βt 随着时间步推移的变化(DDPM 中 T = 1000 T=1000 T=1000,本文中 T = 4000 T=4000 T=4000)。可以看到,除了在 t = 0 t=0 t=0 附近,二者的比值很快就收敛到 1 了,也就是说,除了在加噪前期,二者在大部分时间下都是相等的。而在加噪的前期,模型处理的都是一些不易感知的细节。还可以看到,随着我们增加更多的总时间步数,二者就更早地收敛到相等了。因此方差选择固定为哪个,对生成结果的感知质量影响不大,并且随着训练总步数的增加, μ θ ( x t , t ) \mu_\theta(x_t,t) μθ(xt,t) 的影响比 Σ θ ( x t , t ) \Sigma_\theta(x_t,t) Σθ(xt,t) 的影响更大。
当然,上述分析说明了固定方差对采样质量的影响不大,但是对 log likelihood 可不是这样。下图展示了 VLB 损失项(对数尺度)随着时间步的变化,可以看到,最前的几步 VLB 损失项下降的幅度最大,也就是说扩散过程的前期对 VLB 损失项的影响最大,而这恰好是不同方差存在差异的部分,因此一个更好的方差策略可能对 log likelihood 的提升很关键。
图 1 还展示了,合理的
Σ
θ
(
x
t
,
t
)
\Sigma_\theta(x_t,t)
Σθ(xt,t) 的范围很小,即使是在对数尺度上,神经网络直接预测
Σ
θ
(
x
t
,
t
)
\Sigma_\theta(x_t,t)
Σθ(xt,t) 也很困难。因此,作者提出将对方差的预测参数化为预测
β
t
\beta_t
βt 和
β
~
t
\tilde\beta_t
β~t 之间的插值。具体来说,神经网络输出一个向量
v
v
v ,然后通过插值求出方差:
Σ
θ
(
x
t
,
t
)
=
exp
(
v
log
β
t
+
(
1
−
v
)
log
β
~
t
)
\Sigma_\theta(x_t,t)=\exp(v\log\beta_t+(1-v)\log\tilde\beta_t)
Σθ(xt,t)=exp(vlogβt+(1−v)logβ~t)
作者没有对
v
v
v 的数值施加任何限制,也就是理论上来说方差的值有可能超出内插的范围,但实际实验中并没有观察到这种现象。所以原本
Σ
θ
(
x
t
,
t
)
\Sigma_\theta(x_t,t)
Σθ(xt,t) 的上下界确实是已经足够。
由于我们要学习方差 ,而 DDPM 的
L
simple
L_\text{simple}
Lsimple 并不包含
Σ
θ
(
x
t
,
t
)
\Sigma_\theta(x_t,t)
Σθ(xt,t)。因此我们定义一个混合的损失:
L
hybrid
=
L
simple
+
λ
L
vlb
L_\text{hybrid}=L_\text{simple}+\lambda L_\text{vlb}
Lhybrid=Lsimple+λLvlb
本文中设
λ
=
0.001
\lambda=0.001
λ=0.001 ,同时对
μ
θ
(
x
t
,
t
)
\mu_\theta(x_t,t)
μθ(xt,t) 的输出关于
L
vlb
L_\text{vlb}
Lvlb 施加梯度停止(stop gradient),从而
L
vlb
L_\text{vlb}
Lvlb 可以引导对方差
Σ
θ
(
x
t
,
t
)
\Sigma_\theta(x_t,t)
Σθ(xt,t) 的学习,而方差
μ
θ
(
x
t
,
t
)
\mu_\theta(x_t,t)
μθ(xt,t) 仍旧主要受
L
simple
L_\text{simple}
Lsimple 影响。
改进 noise schedule
作者指出 DDPM 中的 linear noise schedule 仅适合高分辨率图像,而在低分辨率(64x64、32x32 等)上表现不是最好。具体原因是 linear noise schedule 在前向加噪的后期噪声过大,对采样质量的提升帮助不大。下图展示了使用不同的 noise schedule 时,采样质量随跳过的采样步数的关系。可以看到,使用 linear schedule 时,即使跳过三成的采样步数,采样质量也没有受到多少影响,证明了上述 linear schedule 加噪后期对采样质量提升帮助不大的观点。
为了解决这个问题,从
α
ˉ
t
\bar\alpha_t
αˉt 的角度,作者定义了一种新的 cosine noise schedule:
α
ˉ
t
=
f
(
t
)
f
(
0
)
,
f
(
t
)
=
cos
(
t
/
T
+
s
1
+
s
⋅
π
2
)
2
\bar\alpha_t=\frac{f(t)}{f(0)},\ \ f(t)=\cos(\frac{t/T+s}{1+s}\cdot\frac{\pi}{2})^2
αˉt=f(0)f(t), f(t)=cos(1+st/T+s⋅2π)2
注意经过
β
t
=
1
−
α
ˉ
t
α
ˉ
t
−
1
\beta_t=1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}
βt=1−αˉt−1αˉt ,就转换成
β
t
\beta_t
βt 的 noise schedule 定义方式了。在实际中,还对
β
t
\beta_t
βt 进行了截断,使其不高于 0.999,从而避免在
t
=
T
t=T
t=T 附近出现奇点。
作者提出 cosine noise schedule 的设计理念是期望 α ˉ t \bar\alpha_t αˉt 在扩散过程中按照接近线性的趋势下降。并且在开始 t = 0 t=0 t=0 和结束 t = T t=T t=T 附近较为平缓。下图展示了 DDPM 中的 linear noise schedule 和本文的 cosine noise schedule 的 α ˉ t \bar\alpha_t αˉt 值的变化趋势。可以看到,linear noise schedule 下的 α ˉ t \bar\alpha_t αˉt 在扩散过程后期过快地收敛到 0 了,导致噪声过大,采样效率低,而 cosine noise schedule 则达成了上述想要的趋势。这里选用余弦是因为它是一种比较常见的能够满足作者设计理念的函数曲线,其他的能够符合上述两点的函数也可行。
作者还还发现扩散过程前期噪声太小的话,网络预测噪声 ϵ \epsilon ϵ 会比较困难,因此增加了一个小的偏移量 s s s,避免在 t = 0 t=0 t=0 附近噪声过小。具体的 s s s 值考虑到 β 0 \sqrt{\beta_0} β0 需要比像素尺寸 1 / 127.5 1/127.5 1/127.5 稍大,因此取 s = 0.008 s=0.008 s=0.008 。
稳定VLB训练
作者还期望直接优化 L vlb L_\text{vlb} Lvlb ,而不是优化 L hybrid L_\text{hybrid} Lhybrid ,从而达到更高的 log likelihood。但作者在实验中发现 L vlb L_\text{vlb} Lvlb 很难直接优化(至少在 ImageNet 64x64 上)。下图展示了不同的损失函数在训练时的 loss 下降曲线。很明显可以看到, L vlb L_\text{vlb} Lvlb 和 L hybrid L_\text{hybrid} Lhybrid 都比较震荡不稳定,但是 L hybrid L_\text{hybrid} Lhybrid 明显损失收敛的值更低一些。
作者猜测这是由于 L vlb L_\text{vlb} Lvlb 的梯度中噪声比 L hybrid L_\text{hybrid} Lhybrid 更大导致的,并通过分别测试两种训练目标训练触摸型的梯度噪声尺度(如下图所示)验证了这一点。因此,为了直接优化 log likelihood,作者想要找一种方法来降低 L vlb L_\text{vlb} Lvlb 损失的方差。
注意到图 2 中
L
vlb
L_\text{vlb}
Lvlb 在不同的时间步损失项的值差异非常大,作者据此猜测是由于均匀采样
t
t
t 导致训练目标
L
vlb
L_\text{vlb}
Lvlb 的噪声震荡。为解决这个问题,作者提出了一种对时间步
t
t
t 的重要性采样的策略:
L
vlb
=
E
t
∼
p
t
[
L
t
p
t
]
,
where
p
t
∝
E
[
L
t
2
]
and
∑
p
t
=
1
L_\text{vlb}=\mathbb{E}_{t\sim p_t}[\frac{L_t}{p_t}],\ \ \ \text{where}\ p_t\propto \sqrt{\mathbb{E}[L_t^2]}\ \text{and}\ \sum{p_t}=1
Lvlb=Et∼pt[ptLt], where pt∝E[Lt2] and ∑pt=1
其实就是根据训练过程中每个时间步
t
t
t 的损失值
L
t
L_t
Lt 进行重要性加权,对损失值更高的
t
t
t 分配更低的权重,从而使得整体损失噪声波动降低,稳定训练。由于我们事先不知道
E
[
L
t
2
]
\mathbb{E}[L^2_t]
E[Lt2] 的值,作者选择保存前 10 步的 loss 值,然后再训练过程中进行动态更新。在训练开始时,先使用均匀采样,直到每个时间步
t
t
t 都被采样了至少 10 次。从上图可以看到,使用了重要性采样策略的
L
vlb
(
resampled
)
L_\text{vlb}(\text{resampled})
Lvlb(resampled) 明显噪声更低,更稳定,且达到了更好的 log likelihood。
总结
Improved DDPM 是一篇重要的对 DDPM 的改进工作,其分析、解决问题的思路非常清晰,值得学习。iDDPM 提出了可学习方差、余弦噪声计划、非均匀的时间步采样策略等技术,许多对后续扩散模型的改进工作由深远的影响。另外,论文本身的写作也很不错,读起来很顺畅,推荐大家去读原文。