扩散模型加速采样算法《Denoising Diffusion lmplicit Models》论文原理
参考视频
简介
针对DDPM的最大的缺点——依据马尔科夫链过程,需要对图片进行逐步加噪和减噪,导致step步数太多,计算太慢,基于DDPM,DDIM主要有两项改进:
- 对于一个已经训练好的DDPM,只需要对采样公式做简单的修改,模型就能在去噪时跳步骤,在一步去噪迭代中直接预测若干次去噪后的结果。
- DDIM论文推广了DDPM的数学模型,打破了马尔科夫链的过程,从更高的视角定义了DDPM的反向过程(去噪过程)。在这个新数学模型下,我们可以自定义模型的噪声强度,让同一个训练好的DDPM有不同的采样效果。
原理和数学推导
- DDPM的损失函数特点
DDPM的损失函数 L s i m p l e L_{simple} Lsimple只依赖于边缘分布,不直接依赖于联合分布
联合分布以什么形式出现并不会影响训练DDPM,因此可以设计非马尔科夫的扩散过程,并且保证 q ( x t ∣ x 0 ) q(\boldsymbol{x}_t|\boldsymbol{x}_0) q(xt∣x0)是一致的,则可以和DDPM共享同一个目标函数。
DDPM在推导出
L
s
i
m
p
l
e
L_{simple}
Lsimple过程中,并没有用到
q
(
x
1
:
T
∣
x
0
)
q(\boldsymbol{x}_{1:T}|\boldsymbol{x}_0)
q(x1:T∣x0)的具体形式,只是基于贝叶斯公式和
q
(
x
t
∣
x
t
−
1
,
x
0
)
q(x_{t}|x_{t-1},x_{0})
q(xt∣xt−1,x0)、
q
(
x
t
∣
x
0
)
q(\boldsymbol{x}_t|\boldsymbol{x}_0)
q(xt∣x0)表达式。
在训练DDPM所用到的
L
s
i
m
p
l
e
L_{simple}
Lsimple loss中,也没有采用跟
q
(
x
t
∣
x
t
−
1
,
x
0
)
q(x_{t}|x_{t-1},x_{0})
q(xt∣xt−1,x0)相关的系数,而是直接选择将预测噪音的权重设置为1。由于噪音项是来自
q
(
x
t
∣
x
0
)
q(\boldsymbol{x}_t|\boldsymbol{x}_0)
q(xt∣x0)的采样,因此,DDPM的目标函数其实只由
q
(
x
t
∣
x
0
)
q(\boldsymbol{x}_t|\boldsymbol{x}_0)
q(xt∣x0)表达式决定。
因此,使用非马尔可夫性质可以更具有一般性,只要保证 q ( x t ∣ x 0 ) q(\boldsymbol{x}_t|\boldsymbol{x}_0) q(xt∣x0)的形式不变,就可以直接复用训练好的DDPM,在反向过程中使用新的概率分布来进行随机采样。
- 非马尔科夫链的前向扩散过程
作者设计了新的非马尔科夫链的前向扩散过程,定义了相同形式的分布函数来代替DDPM中基于马尔科夫性质的分布函数
q
σ
(
x
1
:
T
∣
x
0
)
:
=
q
σ
(
x
T
∣
x
0
)
∏
t
=
2
T
q
σ
(
x
t
−
1
∣
x
t
,
x
0
)
∗
E
q
.
(
6
)
∗
q_\sigma(\boldsymbol{x}_{1:T}|\boldsymbol{x}_0):=q_\sigma(\boldsymbol{x}_T|\boldsymbol{x}_0)\prod_{t=2}^Tq_\sigma(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0) *Eq.(6)*
qσ(x1:T∣x0):=qσ(xT∣x0)t=2∏Tqσ(xt−1∣xt,x0)∗Eq.(6)∗
这里
q
σ
(
x
T
∣
x
0
)
=
N
(
α
T
x
0
,
(
1
−
α
T
)
I
)
q_\sigma(\boldsymbol{x}_T|\boldsymbol{x}_0)=\mathcal{N}(\sqrt{\alpha_T}\boldsymbol{x}_0,(1-\alpha_T)\boldsymbol{I})
qσ(xT∣x0)=N(αTx0,(1−αT)I)并且当
t
>
1
t>1
t>1时
q
σ
(
x
t
−
1
∣
x
t
,
x
0
)
=
N
(
α
t
−
1
x
0
+
1
−
α
t
−
1
−
σ
t
2
⋅
x
t
−
α
t
x
0
1
−
α
t
,
σ
t
2
I
)
∗
E
q
.
(
7
)
∗
q_{\sigma}(\boldsymbol{x}_{t-1}|\boldsymbol{x}_{t},\boldsymbol{x}_{0})=\mathcal{N}\left(\sqrt{\alpha_{t-1}}\boldsymbol{x}_{0}+\sqrt{1-\alpha_{t-1}-\sigma_{t}^{2}}\cdot\frac{\boldsymbol{x}_{t}-\sqrt{\alpha_{t}}\boldsymbol{x}_{0}}{\sqrt{1-\alpha_{t}}},\sigma_{t}^{2}\boldsymbol{I}\right) *Eq.(7)*
qσ(xt−1∣xt,x0)=N(αt−1x0+1−αt−1−σt2⋅1−αtxt−αtx0,σt2I)∗Eq.(7)∗
σ
∈
R
≥
0
T
\sigma\in\mathbb{R}_{\geq0}^T
σ∈R≥0T是实数向量的超参数。选择均值函数是为了确保对于所有t,都有
q
σ
(
x
t
∣
x
0
)
=
N
(
α
t
x
0
,
(
1
−
α
t
)
I
)
q_{\sigma}(\boldsymbol{x}_{t}|\boldsymbol{x}_{0}) = {\mathcal{N}(\sqrt{\alpha_{t}}\boldsymbol{x}_{0},(1-\alpha_{t})\boldsymbol{I})}
qσ(xt∣x0)=N(αtx0,(1−αt)I),这样它就定义了一个“联合”推理分布,与期望的“边际”相匹配。前向过程可由贝叶斯规则导出:
q
σ
(
x
t
∣
x
t
−
1
,
x
0
)
=
q
σ
(
x
t
−
1
∣
x
t
,
x
0
)
q
σ
(
x
t
∣
x
0
)
q
σ
(
x
t
−
1
∣
x
0
)
q_\sigma(\boldsymbol{x}_t|\boldsymbol{x}_{t-1},\boldsymbol{x}_0)=\frac{q_\sigma(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0)q_\sigma(\boldsymbol{x}_t|\boldsymbol{x}_0)}{q_\sigma(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)}
qσ(xt∣xt−1,x0)=qσ(xt−1∣x0)qσ(xt−1∣xt,x0)qσ(xt∣x0)
可以证明对于每一个时刻 t t t都能满足 q σ ( x t ∣ x 0 ) = N ( α t x 0 , ( 1 − α t ) I ) q_\sigma(\boldsymbol{x}_t|\boldsymbol{x}_0) = \mathcal{N}(\sqrt{\alpha_t}\boldsymbol{x}_0,(1-\alpha_t)\boldsymbol{I}) qσ(xt∣x0)=N(αtx0,(1−αt)I),这与DDPM每一时刻的分布是相同形式的,因此可以使用相同的目标函数。
证明过程:
与DDPM相比,随着
σ
\sigma
σ的不同,高斯分布的均值和方程不同,重参数化的参数也就不同,进行的随机采样也会不一样。
- 非马尔科夫扩散逆过程的采样
与DDPM类似的过程
与DDPM一样,使用
f
θ
(
t
)
(
x
t
)
:
=
(
x
t
−
1
−
α
t
⋅
ϵ
θ
(
t
)
(
x
t
)
)
/
α
t
f_\theta^{(t)}(\boldsymbol{x}_t):=(\boldsymbol{x}_t-\sqrt{1-\alpha_t}\cdot\epsilon_\theta^{(t)}(\boldsymbol{x}_t))/\sqrt{\alpha_t}
fθ(t)(xt):=(xt−1−αt⋅ϵθ(t)(xt))/αt
将
x
0
{x}_0
x0预测出来,则方向过程
p
θ
(
t
)
(
x
t
−
1
∣
x
t
)
=
{
N
(
f
θ
(
1
)
(
x
1
)
,
σ
1
2
I
)
if
t
=
1
q
σ
(
x
t
−
1
∣
x
t
,
f
θ
(
t
)
(
x
t
)
)
otherwise
,
p_\theta^{(t)}(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)=\begin{cases}\mathcal{N}(f_\theta^{(1)}(\boldsymbol{x}_1),\sigma_1^2\boldsymbol{I})&\text{if} t=1\\q_\sigma(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,f_\theta^{(t)}(\boldsymbol{x}_t))&\text{otherwise},\end{cases}
pθ(t)(xt−1∣xt)={N(fθ(1)(x1),σ12I)qσ(xt−1∣xt,fθ(t)(xt))ift=1otherwise,
可以证明
J
σ
=
L
γ
+
C
J_\sigma=L_\gamma+C
Jσ=Lγ+C,并且优化
J
σ
J_\sigma
Jσ就是优化了
L
γ
L_\gamma
Lγ。
证明过程:
特殊的采样
将DDPM中的
P
(
x
t
−
1
∣
x
t
,
x
0
)
∼
N
(
a
t
(
1
−
a
‾
t
−
1
)
1
−
a
‾
t
x
t
+
a
‾
t
−
1
(
1
−
a
t
)
1
−
a
‾
t
×
x
t
−
1
−
a
‾
t
×
ϵ
a
‾
t
,
(
β
t
(
1
−
a
‾
t
−
1
)
1
−
a
‾
t
)
2
)
P(x_{t-1}|x_t,x_0)\sim N\left(\frac{\sqrt{a_t}(1-\overline{a}_{t-1})}{1-\overline{a}_t}x_t+\frac{\sqrt{\overline{a}_{t-1}}(1-a_t)}{1-\overline{a}_t}\times\frac{x_t-\sqrt{1-\overline{a}_t}\times\epsilon}{\sqrt{\overline{a}_t}},\left(\sqrt{\frac{\beta_t(1-\overline{a}_{t-1})}{1-\overline{a}_t}}\right)^2\right)
P(xt−1∣xt,x0)∼N
1−atat(1−at−1)xt+1−atat−1(1−at)×atxt−1−at×ϵ,
1−atβt(1−at−1)
2
化为
x
t
−
1
=
α
t
−
1
(
x
t
−
1
−
α
t
ϵ
θ
(
t
)
(
x
t
)
α
t
)
⏟
“predicted
x
0
”
+
1
−
α
t
−
1
−
σ
t
2
⋅
ϵ
θ
(
t
)
(
x
t
)
⏟
“direction pointing to
x
t
”
+
σ
t
ϵ
t
⏟
random noise
x_{t-1}=\sqrt{\alpha_{t-1}}\underbrace{\left(\frac{\boldsymbol{x}_t-\sqrt{1-\alpha_t}\epsilon_\theta^{(t)}(\boldsymbol{x}_t)}{\sqrt{\alpha_t}}\right)}_{\text{“predicted }\boldsymbol{x}_0\text{”}}+\underbrace{\sqrt{1-\alpha_{t-1}-\sigma_t^2}\cdot\boldsymbol{\epsilon}_\theta^{(t)}(\boldsymbol{x}_t)}_{\text{“direction pointing to }\boldsymbol{x}_t\text{”}}+\underbrace{\sigma_t\epsilon_t}_{\text{random noise}}
xt−1=αt−1“predicted x0”
(αtxt−1−αtϵθ(t)(xt))+“direction pointing to xt”
1−αt−1−σt2⋅ϵθ(t)(xt)+random noise
σtϵt
分割出了
σ
t
ϵ
t
\sigma_t\epsilon_t
σtϵt一个随机噪声项。不同的
σ
\sigma
σ,高斯分布的均值和方程不同,重参数化的参数也就不同,进行的随机采样也会不一样。但是对于模型来说都是预测
ϵ
θ
(
t
)
(
x
t
)
\epsilon_\theta^{(t)}(x_t)
ϵθ(t)(xt),模型的目标函数跟
σ
\sigma
σ是无关的,
σ
\sigma
σ只影响采样。当
σ
t
=
(
1
−
α
t
−
1
)
/
(
1
−
α
t
)
1
−
α
t
/
α
t
−
1
\sigma_{t} = \sqrt{(1-\alpha_{t-1})/(1-\alpha_{t})}\sqrt{1-\alpha_{t}/\alpha_{t-1}}
σt=(1−αt−1)/(1−αt)1−αt/αt−1时,前向过程便变成了马尔科夫性质,也就退化为DDPM。
当 σ t = 0 \sigma_{t}=0 σt=0时,便丢失了随机噪声项,此时扩散过程是确定的高斯分布的随机采样。因此可以不需要每一个时刻去一步一步执行(因为是确定性的),减少采样次数,只要模型 ϵ θ ( t ) ( x t ) \epsilon_\theta^{(t)}(x_t) ϵθ(t)(xt)能够预测准确就能获得不错的效果。不过, ϵ θ ( t ) ( x t ) \epsilon_\theta^{(t)}(x_t) ϵθ(t)(xt)很难做到一步就能预测准确,因此多次的采样是必要的。
- 采样的特殊性带来的加速采样技巧——respacing
可以选取时刻序列的一个子集
{
x
τ
1
,
…
,
x
τ
S
}
\{x_{\tau_1},\ldots,x_{\tau_S}\}
{xτ1,…,xτS}作为反向过程的采样,令
σ
τ
i
(
η
)
=
η
(
1
−
α
τ
i
−
1
)
/
(
1
−
α
τ
i
)
1
−
α
τ
i
/
α
τ
i
−
1
\sigma_{\tau_{i}}(\eta)=\eta\sqrt{(1-\alpha_{\tau_{i-1}})/(1-\alpha_{\tau_{i}})}\sqrt{1-\alpha_{\tau_{i}}/\alpha_{\tau_{i-1}}}
στi(η)=η(1−ατi−1)/(1−ατi)1−ατi/ατi−1
当
η
=
1
\eta=1
η=1时,此时为DDPM;当
η
=
0
\eta=0
η=0时,
σ
t
=
0
\sigma_{t}=0
σt=0,此时为DDIM。
并且,如果令随机噪音的方差大于1,即令
σ
^
τ
i
=
1
−
α
τ
i
/
α
τ
i
−
1
\hat{\sigma}_{\tau_{i}} = \sqrt{1-\alpha_{\tau_{i}}/\alpha_{\tau_{i-1}}}
σ^τi=1−ατi/ατi−1
得到不同的实验结果:
相关论文和代码下载
有时候论文网站arXiv.org打开比较慢,已经将相关论文和代码上传到网盘,需要的可以自取
链接: https://pan.baidu.com/s/1J1h8R4KyY7k6NgS2t7YOZg?pwd=3ss8
可以关注公众号:
搜索:福尔马林灌汤包