Diffusion 公式推导 中对 DDPM 进行了推导,本文接上文对 DDIM 进行推导。
六. 模型改进
从扩散模型的推理过程不难看出,DDPM 有一个致命缺点 —— 推理速度过慢,因为逆扩散是从 x T x_{T} xT 到 x 0 x_{0} x0 的完整过程,无法跳过中间的迭代步骤。为了加快推理过程,DDIM (Denoising Diffusion Implicit Models) 维持 DDPM 的扩散过程保持不变,对其去噪过程进行了改进,采用一个非马尔科夫过程,使得生成过程可以在更少的时间步内完成。下面对其进行推导 1:
DDIM 假设
q
(
x
t
∗
∣
x
0
,
x
t
)
q (x_{t^*} \mid x_0, x_t)
q(xt∗∣x0,xt) 是一个高斯分布(原文中使用的是
t
−
1
t-1
t−1,这里为了突出隔多步采样,使用
t
∗
t^*
t∗ 标记时间步,
0
<
t
∗
<
t
0<t^*<t
0<t∗<t),对其进行待定系数:
q
(
x
t
∗
∣
x
0
,
x
t
)
∼
N
(
k
x
0
+
m
x
t
,
σ
2
)
(21)
q (x_{t^*} \mid x_0, x_t) \sim \mathcal{N}\left(kx_0+mx_t, \sigma^2\right) \tag{21}
q(xt∗∣x0,xt)∼N(kx0+mxt,σ2)(21)
因此有:
x
t
∗
=
k
x
0
+
m
x
t
+
σ
ϵ
其中
ϵ
∼
N
(
0
,
I
)
(22)
x_{t^*} = kx_0+mx_t + \sigma \epsilon \quad \text{ 其中 } \epsilon \sim \mathcal{N}(0, \bold I) \tag{22}
xt∗=kx0+mxt+σϵ 其中 ϵ∼N(0,I)(22)
将(7)式代入,得到:
x
t
∗
=
k
x
0
+
m
x
t
+
σ
ϵ
=
k
x
0
+
m
(
α
‾
t
x
0
+
1
−
α
‾
t
ϵ
‾
0
)
+
σ
ϵ
=
(
k
+
m
α
‾
t
)
x
0
+
m
1
−
α
‾
t
ϵ
‾
0
+
σ
ϵ
=
(
k
+
m
α
‾
t
)
x
0
+
ϵ
′
(23)
\begin{aligned} x_{t^*} & = kx_0+mx_t + \sigma \epsilon\\ & = kx_0+m(\sqrt{\overline{\alpha}_t} x_{0} + \sqrt{1-\overline{\alpha}_t} \overline \epsilon_{0}) + \sigma \epsilon\\ & = (k+m\sqrt{\overline{\alpha}_t})x_0 + m\sqrt{1-\overline{\alpha}_t} \overline \epsilon_{0} + \sigma \epsilon\\ & = (k+m\sqrt{\overline{\alpha}_t})x_0 + \epsilon' \end{aligned} \tag{23}
xt∗=kx0+mxt+σϵ=kx0+m(αtx0+1−αtϵ0)+σϵ=(k+mαt)x0+m1−αtϵ0+σϵ=(k+mαt)x0+ϵ′(23)
其中, m 1 − α ‾ t ϵ ‾ 0 + σ ϵ m\sqrt{1-\overline{\alpha}_t} \overline \epsilon_{0} + \sigma \epsilon m1−αtϵ0+σϵ 可以合并成 ϵ ′ \epsilon' ϵ′ 是因为高斯分布的可加性,因此有 ϵ ′ ∼ N ( 0 , m 2 ( 1 − α ‾ t ) + σ 2 ) \epsilon' \sim \mathcal{N}(0, m^2(1-\overline{\alpha}_t) + \sigma^2) ϵ′∼N(0,m2(1−αt)+σ2)。
将(7)式变换成 x t ∗ x_{t^*} xt∗ 的形式,然后和(23)式联立:
x t ∗ = α ‾ t ∗ x 0 + 1 − α ‾ t ∗ ϵ ‾ 0 x t ∗ = ( k + m α ‾ t ) x 0 + ϵ ′ \begin{aligned} x_{t^*} &= \sqrt{\overline{\alpha}_{t^*}} x_{0} + \sqrt{1-\overline{\alpha}_{t^*}} \overline \epsilon_{0}\\ x_{t^*} &= (k+m\sqrt{\overline{\alpha}_t})x_0 + \epsilon' \end{aligned} xt∗xt∗=αt∗x0+1−αt∗ϵ0=(k+mαt)x0+ϵ′
对应项系数相等得到:
α
‾
t
∗
=
k
+
m
α
‾
t
1
−
α
‾
t
∗
=
m
2
(
1
−
α
‾
t
)
+
σ
2
(24)
\begin{aligned} \sqrt{\overline{\alpha}_{t^*}} &= k+m\sqrt{\overline{\alpha}_t}\\ 1-\overline{\alpha}_{t^*} &= m^2(1-\overline{\alpha}_t) + \sigma^2 \end{aligned} \tag{24}
αt∗1−αt∗=k+mαt=m2(1−αt)+σ2(24)
因此有:
m
=
1
−
α
‾
t
∗
−
σ
2
1
−
α
‾
t
k
=
α
‾
t
∗
−
1
−
α
‾
t
∗
−
σ
2
1
−
α
‾
t
α
‾
t
(25)
\begin{aligned} m &= \sqrt{\frac{1-\overline{\alpha}_{t^*} - \sigma^2}{1-\overline{\alpha}_t}}\\ k &= \sqrt{\overline{\alpha}_{t^*}} - \sqrt{\frac{1-\overline{\alpha}_{t^*} - \sigma^2}{1-\overline{\alpha}_t}} \sqrt{\overline{\alpha}_t}\\ \end{aligned} \tag{25}
mk=1−αt1−αt∗−σ2=αt∗−1−αt1−αt∗−σ2αt(25)
和 DDPM 逆扩散过程一样,代入(7)式将
x
0
x_0
x0 替换成
x
t
x_t
xt 表示,再将(25)式代入(22)式:
x
t
∗
=
k
x
0
+
m
x
t
+
σ
ϵ
=
(
α
‾
t
∗
−
1
−
α
‾
t
∗
−
σ
2
1
−
α
‾
t
α
‾
t
)
(
x
t
−
1
−
α
‾
t
ϵ
‾
0
α
‾
t
)
+
1
−
α
‾
t
∗
−
σ
2
1
−
α
‾
t
x
t
+
σ
ϵ
=
α
‾
t
∗
α
‾
t
x
t
+
(
1
−
α
‾
t
∗
−
σ
2
−
α
‾
t
∗
(
1
−
α
‾
t
)
α
‾
t
)
ϵ
‾
0
+
σ
ϵ
(26)
\begin{aligned} x_{t^*} &= kx_0+mx_t + \sigma \epsilon\\ &= (\sqrt{\overline{\alpha}_{t^*}} - \sqrt{\frac{1-\overline{\alpha}_{t^*} - \sigma^2}{1-\overline{\alpha}_t}} \sqrt{\overline{\alpha}_t})(\frac{x_t - \sqrt{1-\overline{\alpha}_t} \overline \epsilon_{0}}{\sqrt{\overline{\alpha}_t}}) + \sqrt{\frac{1-\overline{\alpha}_{t^*} - \sigma^2}{1-\overline{\alpha}_t}}x_t + \sigma \epsilon\\ &= \sqrt{\frac{\overline{\alpha}_{t^*}}{\overline{\alpha}_t}}x_t + (\sqrt{1-\overline{\alpha}_{t^*} - \sigma^2}-\sqrt{\frac{\overline{\alpha}_{t^*}(1-\overline{\alpha}_t)}{\overline{\alpha}_t}}) \overline \epsilon_{0} + \sigma \epsilon\\ \end{aligned} \tag{26}
xt∗=kx0+mxt+σϵ=(αt∗−1−αt1−αt∗−σ2αt)(αtxt−1−αtϵ0)+1−αt1−αt∗−σ2xt+σϵ=αtαt∗xt+(1−αt∗−σ2−αtαt∗(1−αt))ϵ0+σϵ(26)
需要注意的是,此处的
σ
\sigma
σ 并非(20)式中的
1
−
α
‾
t
−
1
1
−
α
‾
t
⋅
β
t
\sqrt{\frac{1-\overline{\alpha}_{t-1}}{1-\overline{\alpha}_t} \cdot \beta_t}
1−αt1−αt−1⋅βt,而是(21)式中待定的正态分布标准差。DDIM 将此
σ
\sigma
σ 置为 0,这样整个生成过程就是一个确定性的过程,不再引入随机噪声。在这样的情况下,一旦
x
T
x_T
xT 确定,那么 DDIM 的样本生成就变成了确定的过程。因此,DDIM 的去噪公式可以写成:
x
t
∗
=
α
‾
t
∗
α
‾
t
x
t
+
(
1
−
α
‾
t
∗
−
α
‾
t
∗
(
1
−
α
‾
t
)
α
‾
t
)
ϵ
‾
0
(27)
\begin{aligned} x_{t^*} &= \sqrt{\frac{\overline{\alpha}_{t^*}}{\overline{\alpha}_t}}x_t + (\sqrt{1-\overline{\alpha}_{t^*}}-\sqrt{\frac{\overline{\alpha}_{t^*}(1-\overline{\alpha}_t)}{\overline{\alpha}_t}}) \overline \epsilon_{0}\\ \end{aligned} \tag{27}
xt∗=αtαt∗xt+(1−αt∗−αtαt∗(1−αt))ϵ0(27)
上式中 ϵ ‾ 0 \overline \epsilon_{0} ϵ0 通过 U-Net 进行预测,其他都是已知参数。
需要注意的是,DDIM 原文中将 DDPM 中的 α ‾ t \overline{\alpha}_t αt 都写成了 α t \alpha_t αt,并且使用的是 t − 1 t-1 t−1 而非 t ∗ t^* t∗,即:
x t − 1 = α t − 1 α t x t + ( 1 − α t − 1 − α t − 1 ( 1 − α t ) α t ) ϵ θ ( x t , t ) (28) x_{t-1} = \sqrt{\frac{\alpha_{t-1}}{\alpha_t}}x_t + (\sqrt{1-\alpha_{t-1}}-\sqrt{\frac{\alpha_{t-1}(1-\alpha_t)}{\alpha_t}}) \epsilon_{\theta}(x_t, t) \tag{28} xt−1=αtαt−1xt+(1−αt−1−αtαt−1(1−αt))ϵθ(xt,t)(28)
更多推导过程参见 DDPM与DDIM简洁版总结、DDPM和DDIM公式推导(精简版)。也有从相隔多个迭代步数采样向前推的,参见 一个视频看懂DDIM凭什么加速采样|扩散模型相关。
七. 重建
DDIM 原文还额外讨论了重建和插值这两个问题,这里主要介绍重建 2。
对(28)式变换得到:
x
t
−
1
α
t
−
1
=
x
t
α
t
+
(
1
−
α
t
−
1
α
t
−
1
−
1
−
α
t
α
t
)
ϵ
θ
(
x
t
,
t
)
\frac{x_{t-1}}{\sqrt{\alpha_{t-1}}}=\frac{x_t}{\sqrt{\alpha_t}}+\left(\sqrt{\frac{1-\alpha_{t-1}}{\alpha_{t-1}}}-\sqrt{\frac{1-\alpha_t}{\alpha_t}}\right) \epsilon_\theta\left(x_t, t\right)
αt−1xt−1=αtxt+(αt−11−αt−1−αt1−αt)ϵθ(xt,t)
当
T
T
T 足够大时,上式其实可以看成用欧拉法来求解一个常微分方程 (ordinary differential equation, ODE):
x
t
−
Δ
t
α
t
−
Δ
t
=
x
t
α
t
+
(
1
−
α
t
−
Δ
t
α
t
−
Δ
t
−
1
−
α
t
α
t
)
ϵ
θ
(
x
t
,
t
)
\frac{x_{t-\Delta t}}{\sqrt{\alpha_{t-\Delta t}}}=\frac{x_t}{\sqrt{\alpha_t}}+\left(\sqrt{\frac{1-\alpha_{t-\Delta t}}{\alpha_{t-\Delta t}}}-\sqrt{\frac{1-\alpha_t}{\alpha_t}}\right) \epsilon_\theta\left(x_t, t\right)
αt−Δtxt−Δt=αtxt+(αt−Δt1−αt−Δt−αt1−αt)ϵθ(xt,t)
将
x
x
x 和
α
\alpha
α 都视为
t
t
t 的函数,令
σ
=
1
−
α
/
α
\sigma=\sqrt{1-\alpha} / \sqrt{\alpha}
σ=1−α/α,
x
‾
=
x
/
α
\overline{x}=x / \sqrt{\alpha}
x=x/α,对应 ODE 如下:
d
x
‾
(
t
)
=
ϵ
θ
(
x
‾
(
t
)
σ
2
+
1
,
t
)
d
σ
(
t
)
\mathrm{d} \overline{x}(t)=\epsilon_\theta\left(\frac{\overline{x}(t)}{\sqrt{\sigma^2+1}}, t\right) \mathrm{d} \sigma(t)
dx(t)=ϵθ(σ2+1x(t),t)dσ(t)
因此有:
x
t
+
1
α
t
+
1
=
x
t
α
t
+
(
1
−
α
t
+
1
α
t
+
1
−
1
−
α
t
α
t
)
ϵ
θ
(
x
t
,
t
)
\frac{x_{t+1}}{\sqrt{\alpha_{t+1}}}=\frac{x_t}{\sqrt{\alpha_t}}+\left(\sqrt{\frac{1-\alpha_{t+1}}{\alpha_{t+1}}}-\sqrt{\frac{1-\alpha_t}{\alpha_t}}\right) \epsilon_\theta\left(x_t, t\right)
αt+1xt+1=αtxt+(αt+11−αt+1−αt1−αt)ϵθ(xt,t)
即:
x
t
+
1
=
α
t
+
1
α
t
x
t
+
(
1
−
α
t
+
1
−
α
t
+
1
(
1
−
α
t
)
α
t
)
ϵ
θ
(
x
t
,
t
)
x_{t+1}=\frac{\sqrt{\alpha_{t+1}}}{\sqrt{\alpha_t}}x_t+\left(\sqrt{1-\alpha_{t+1}}-\sqrt{\frac{\alpha_{t+1}(1-\alpha_t)}{\alpha_t}}\right) \epsilon_\theta\left(x_t, t\right)
xt+1=αtαt+1xt+
1−αt+1−αtαt+1(1−αt)
ϵθ(xt,t)
这就是 DDIM 反演的前向公式,常用于重建或者编辑任务中。需要注意的是,这里的 T T T 足够大 的前提条件在大多数情况下无法满足,所以会出现重建或者编辑质量低的情况。
八. 总结
DDIM 和 DDPM 的正向扩散过程完全一样,逆向扩散时可以隔多步进行采样,通过一个确定性映射直接将噪声转换为数据,避免了 DDPM 中的随机性,在减少生成时间步的同时,保持生成图像的高质量。推理过程中的确定性映射表示如下:
x
τ
i
−
1
=
α
ˉ
τ
i
−
1
(
x
τ
i
−
1
−
α
ˉ
τ
i
ϵ
θ
(
x
τ
i
,
τ
i
)
α
ˉ
τ
i
)
+
1
−
α
ˉ
τ
i
−
1
ϵ
θ
(
x
τ
i
,
τ
i
)
\mathbf{x}_{\tau_{i-1}} = \sqrt{\bar{\alpha}_{\tau_{i-1}}} \left( \frac{\mathbf{x}_{\tau_i} - \sqrt{1 - \bar{\alpha}_{\tau_i}} \epsilon_\theta(\mathbf{x}_{\tau_i}, \tau_i)}{\sqrt{\bar{\alpha}_{\tau_i}}} \right) + \sqrt{1 - \bar{\alpha}_{\tau_{i-1}}} \epsilon_\theta(\mathbf{x}_{\tau_i}, \tau_i)
xτi−1=αˉτi−1(αˉτixτi−1−αˉτiϵθ(xτi,τi))+1−αˉτi−1ϵθ(xτi,τi)
其中 τ = { τ 1 , τ 2 , . . . , τ N } \tau = \{ \tau_1, \tau_2, ..., \tau_N \} τ={τ1,τ2,...,τN} 是时间步长序列。通过这个机制,DDIM 可以在生成过程中跳过多个步骤。