扩散模型Diffusion model | DDIM
论文原文:Denoising Diffusion Implicit Models
有关DDPM的解释可以参考我的上一篇博客:扩散模型Diffusion model | DDPM
DDPM的贝叶斯解释
直接根据贝叶斯定理我们有
p
(
x
t
−
1
∣
x
t
)
=
p
(
x
t
∣
x
t
−
1
)
p
(
x
t
−
1
)
p
(
x
t
)
p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)=\frac{p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})p(\boldsymbol{x}_{t-1})}{p(\boldsymbol{x}_t)}
p(xt−1∣xt)=p(xt)p(xt∣xt−1)p(xt−1)
但是
p
(
x
t
−
1
)
,
p
(
x
t
)
p(x_{t−1}),p(x_{t})
p(xt−1),p(xt)难以直接计算,因而转向计算
p
(
x
t
−
1
∣
x
t
,
x
0
)
=
p
(
x
t
∣
x
t
−
1
)
p
(
x
t
−
1
∣
x
0
)
p
(
x
t
∣
x
0
)
p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0)=\frac{p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)}{p(\boldsymbol{x}_t|\boldsymbol{x}_0)}
p(xt−1∣xt,x0)=p(xt∣x0)p(xt∣xt−1)p(xt−1∣x0)
代入各自的表达式得到:(本文中与原文定义不同的是:
β
t
=
1
−
α
t
2
\beta_t=\sqrt{1-\alpha_t^2}
βt=1−αt2,原文是
α
t
=
1
−
β
t
\alpha_t=1-\beta_t
αt=1−βt)
p
(
x
t
−
1
∣
x
t
,
x
0
)
=
N
(
x
t
−
1
;
α
t
β
ˉ
t
−
1
2
β
ˉ
t
2
x
t
+
α
ˉ
t
−
1
β
t
2
β
ˉ
t
2
x
0
,
β
ˉ
t
−
1
2
β
t
2
β
ˉ
t
2
I
)
p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0)=\mathcal{N}\left(\boldsymbol{x}_{t-1};\frac{\alpha_t\bar{\beta}_{t-1}^2}{\bar{\beta}_t^2}\boldsymbol{x}_t+\frac{\bar{\alpha}_{t-1}\beta_t^2}{\bar{\beta}_t^2}\boldsymbol{x}_0,\frac{\bar{\beta}_{t-1}^2\beta_t^2}{\bar{\beta}_t^2}\boldsymbol{I}\right)
p(xt−1∣xt,x0)=N(xt−1;βˉt2αtβˉt−12xt+βˉt2αˉt−1βt2x0,βˉt2βˉt−12βt2I)
用 μ ˉ ( x t ) \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) μˉ(xt)来预估 x 0 \boldsymbol{x}_0 x0损失 ∥ x 0 − μ ˉ ( x t ) ∥ 2 \|\boldsymbol{x}_0-\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)\|^2 ∥x0−μˉ(xt)∥2,就可以消去 p ( x t − 1 ∣ x t , x 0 ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0) p(xt−1∣xt,x0)中的 x 0 \boldsymbol{x}_0 x0,使得它只依赖于 x t \boldsymbol{x}_t xt了。
实际上这也就是个去噪的过程,对应DDPM的Denoising过程。
而由
p
(
x
t
∣
x
0
)
=
N
(
x
t
;
α
ˉ
t
x
0
,
β
ˉ
t
2
I
)
p(\boldsymbol{x}_t|\boldsymbol{x}_0)=\mathcal{N}(\boldsymbol{x}_t;\bar{\alpha}_t\boldsymbol{x}_0,\bar{\beta}_t^2\boldsymbol{I})
p(xt∣x0)=N(xt;αˉtx0,βˉt2I)可以推出
x
0
=
1
α
ˉ
t
(
x
t
−
β
ˉ
t
ε
)
\boldsymbol{x}_0=\frac1{\bar{\alpha}_t}\left(\boldsymbol{x}_t-\bar{\beta}_t\boldsymbol{\varepsilon}\right)
x0=αˉt1(xt−βˉtε),于是我们可以构造
μ
ˉ
(
x
t
)
=
1
α
ˉ
t
(
x
t
−
β
ˉ
t
ϵ
θ
(
x
t
,
t
)
)
\bar{\boldsymbol{\mu}}(x_t) =\frac1{\bar{\alpha}_t}(\boldsymbol{x}_t-\bar{\beta}_t\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_t,t))
μˉ(xt)=αˉt1(xt−βˉtϵθ(xt,t))
代回即可得到DDPM的损失函数。
用 μ ˉ ( x t ) \bar{\boldsymbol{\mu}}(x_t) μˉ(xt)来预估 x 0 x_{0} x0不会太准,它仅仅起到了一个前瞻性的预估作用,然后只用 p ( x t − 1 ∣ x t ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) p(xt−1∣xt)来推进一小步(类似于梯度下降中,找到最速下降方向,但是仅仅下降一小步的预估-修正思想)。
DDIM
DDPM的推导思路为:
x
t
→
m
o
d
e
l
ϵ
θ
(
x
t
,
t
)
→
P
(
x
t
∣
x
0
)
→
P
(
x
0
∣
x
t
,
ϵ
θ
)
x
^
0
(
x
t
,
ϵ
θ
)
→
推导
μ
(
x
t
,
x
^
0
)
,
β
t
→
P
(
x
t
−
1
∣
x
t
,
x
0
)
x
^
t
−
1
x_t\xrightarrow{model}\epsilon_\theta(x_t,t)\xrightarrow{P(x_t|x_0)\to P(x_0|x_t,\epsilon_\theta)}\hat{x}_0(x_t,\epsilon_\theta)\xrightarrow{\text{推导}} \mu ( x _ t , \hat { x }_0),\beta_t\xrightarrow{P(x_{t-1}|x_t,x_0)}\hat{x}_{t-1}
xtmodelϵθ(xt,t)P(xt∣x0)→P(x0∣xt,ϵθ)x^0(xt,ϵθ)推导μ(xt,x^0),βtP(xt−1∣xt,x0)x^t−1
在上述推导中,可以看到
- 损失函数只依赖于 p ( x t ∣ x 0 ) p(\boldsymbol{x}_t|\boldsymbol{x}_0) p(xt∣x0);
- 采样过程只依赖于 p ( x t − 1 ∣ x t ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) p(xt−1∣xt)。
因此由于对马尔可夫假设的依赖,导致DDPM的重建过程需要较多的步长。
但是实际上推理与
p
(
x
t
∣
x
t
−
1
)
p(\boldsymbol{x}_{t}|\boldsymbol{x}_{t-1})
p(xt∣xt−1)好像并没有关系(仅仅是马尔可夫的约束),如果想要加速这个重建过程,可以考虑解除对前向过程
p
(
x
t
∣
x
t
−
1
)
p(\boldsymbol{x}_{t}|\boldsymbol{x}_{t-1})
p(xt∣xt−1)马尔可夫特性的依赖,直接定义分布
p
(
x
t
−
1
∣
x
t
,
x
0
)
p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0)
p(xt−1∣xt,x0),原来的前向过程即可变化为:
q
σ
(
x
1
:
T
∣
x
0
)
=
q
σ
(
x
T
∣
x
0
)
∏
t
=
2
T
q
σ
(
x
t
−
1
∣
x
t
,
x
0
)
q_{\sigma}(\mathbf{x}_{1:T}|\mathbf{x}_{0})=q_{\sigma}(\mathbf{x}_{T}|\mathbf{x}_{0})\prod_{t=2}^{T}q_{\sigma}(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})
qσ(x1:T∣x0)=qσ(xT∣x0)t=2∏Tqσ(xt−1∣xt,x0)
但为了保持前向过程与DDPM等价,需要满足边缘分布条件
∫
p
(
x
t
−
1
∣
x
t
,
x
0
)
p
(
x
t
∣
x
0
)
d
x
t
=
p
(
x
t
−
1
∣
x
0
)
\int p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0)p(\boldsymbol{x}_t|\boldsymbol{x}_0)d\boldsymbol{x}_t=p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)
∫p(xt−1∣xt,x0)p(xt∣x0)dxt=p(xt−1∣x0)
使用待定系数法,即可得到(具体参见苏老师的博客:DDIM = 高观点DDPM)
p
(
x
t
−
1
∣
x
t
,
x
0
)
=
N
(
x
t
−
1
;
β
ˉ
t
−
1
2
−
σ
t
2
β
ˉ
t
x
t
+
(
α
ˉ
t
−
1
−
α
ˉ
t
β
t
−
1
2
−
σ
t
2
β
ˉ
t
)
x
0
,
σ
t
2
I
)
p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0)=\mathcal{N}\left(\boldsymbol{x}_{t-1};\frac{\sqrt{\bar{\beta}_{t-1}^2-\sigma_t^2}}{\bar{\beta}_t}\boldsymbol{x}_t+\left(\bar{\alpha}_{t-1}-\frac{\bar{\alpha}_t\sqrt{\beta_{t-1}^2-\sigma_t^2}}{\bar{\beta}_t}\right)\boldsymbol{x}_0,\sigma_t^2\boldsymbol{I}\right)
p(xt−1∣xt,x0)=N
xt−1;βˉtβˉt−12−σt2xt+
αˉt−1−βˉtαˉtβt−12−σt2
x0,σt2I
那么接下来的过程就与DDPM相同了…
p ( x t − 1 ∣ x t ) ≈ p ( x t − 1 ∣ x t , x 0 = μ ˉ ( x t ) ) = N ( x t − 1 ; 1 α t ( x t − ( β ˉ t − α t β ˉ t − 1 2 − σ t 2 ) ϵ θ ( x t , t ) ) , σ t 2 I ) \begin{aligned} p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)& \approx p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0=\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)) \\ &=\mathcal{N}\left(\boldsymbol{x}_{t-1};\frac1{\alpha_t}\left(\boldsymbol{x}_t-\left(\bar{\beta}_t-\alpha_t\sqrt{\bar{\beta}_{t-1}^2-\sigma_t^2}\right)\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t,t)\right),\sigma_t^2\boldsymbol{I}\right) \end{aligned} p(xt−1∣xt)≈p(xt−1∣xt,x0=μˉ(xt))=N(xt−1;αt1(xt−(βˉt−αtβˉt−12−σt2)ϵθ(xt,t)),σt2I)
此式带有一个自由参数 σ t \sigma_t σt,和DDPM相比训练过程没有变化,但生成过程却有一个可变动的参数 σ t \sigma_t σt,不同 σ t \sigma_t σt的采样过程会呈现出不同的特点。
而当 σ t = β ˉ t − 1 β t β ˉ t \sigma_{t}=\frac{\bar{\beta}_{t-1}\beta_{t}}{\bar{\beta}_{t}} σt=βˉtβˉt−1βt时,DDIM与DDPM等价;当 σ t = 0 \sigma_{t}=0 σt=0时,此时从 x t x_t xt到 x t − 1 x_{t−1} xt−1是一个确定性变换,也就是论文特指的部分。
DDPM的训练结果实质上包含了它的任意子序列参数的训练结果。
原文推导
如果将条件改回 α t = 1 − β t \alpha_t=1-\beta_t αt=1−βt
则
p
(
x
t
∣
x
0
)
p(\boldsymbol{x}_t|\boldsymbol{x}_0)
p(xt∣x0)可以形式化为:
x
0
=
x
t
−
1
−
α
ˉ
t
ϵ
θ
(
t
)
(
x
t
)
α
ˉ
t
x_0=\frac{\boldsymbol{x}_t-\sqrt{1-\bar{\alpha}_t}\epsilon_\theta^{(t)}\left(\boldsymbol{x}_t\right)}{\sqrt{\bar{\alpha}_t}}
x0=αˉtxt−1−αˉtϵθ(t)(xt)
而反向条件概率
q
σ
(
x
t
−
1
∣
x
t
,
x
0
)
q_\sigma\left(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{x}_0\right)
qσ(xt−1∣xt,x0)为:
q
σ
(
x
t
−
1
∣
x
t
,
x
0
)
=
N
(
x
t
−
1
;
α
ˉ
t
−
1
x
0
+
1
−
α
ˉ
t
−
1
−
σ
t
2
x
t
−
α
ˉ
t
x
0
1
−
α
ˉ
t
,
σ
t
2
I
)
q_\sigma\left(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{x}_0\right)=\mathcal{N}\left(\mathbf{x}_{t-1};\sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\frac{\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\mathbf{x}_0}{\sqrt{1-\bar{\alpha}_t}},\sigma_t^2\mathbf{I}\right)
qσ(xt−1∣xt,x0)=N(xt−1;αˉt−1x0+1−αˉt−1−σt21−αˉtxt−αˉtx0,σt2I)
采样过程即为:
x
t
−
1
=
α
ˉ
t
−
1
(
x
t
−
1
−
α
ˉ
t
α
ˉ
t
ϵ
θ
(
t
)
(
x
t
)
α
ˉ
t
)
⏟
predicied
x
0
+
1
−
α
ˉ
t
−
1
−
σ
t
2
⋅
ϵ
θ
(
t
)
(
x
t
)
⏟
direction pointing to
x
t
+
σ
t
⋅
ϵ
t
⏟
random noise
x_{t-1}=\sqrt{\bar{\alpha}_{t-1}}\underbrace{\left(\frac{x_{t}-\sqrt{1-\bar{\alpha}_{t}}}{\sqrt{\bar{\alpha}_{t}}}\frac{\epsilon_{\theta}^{(t)}(x_{t})}{\sqrt{\bar{\alpha}_{t}}}\right)}_{\text{predicied }x_{0}}+\underbrace{\sqrt{1-\bar{\alpha}_{t-1}-\sigma_{t}^{2}}\cdot\epsilon_{\theta}^{(t)}(x_{t})}_{\text{direction pointing to }x_{t}}+\underbrace{\sigma_{t}\cdot\epsilon_{t}}_{\text{random noise}}
xt−1=αˉt−1predicied x0
(αˉtxt−1−αˉtαˉtϵθ(t)(xt))+direction pointing to xt
1−αˉt−1−σt2⋅ϵθ(t)(xt)+random noise
σt⋅ϵt
而可以将
σ
t
\sigma_t
σt进一步定义为:
σ
t
=
η
(
1
−
α
ˉ
t
−
1
)
/
(
1
−
α
ˉ
t
)
1
−
α
ˉ
t
/
α
ˉ
t
−
1
\sigma_t=\eta\sqrt{(1-\bar{\alpha}_{t-1})/(1-\bar{\alpha}_t)}\sqrt{1-\bar{\alpha}_t/\bar{\alpha}_{t-1}}
σt=η(1−αˉt−1)/(1−αˉt)1−αˉt/αˉt−1
如果 η = 0 \eta=0 η=0,生成过程就没有随机噪音了,是一个确定性的过程,论文将这种情况下的模型称为DDIM(denoising diffusion implicit model);而如果 η = 1 \eta=1 η=1,该前向过程变成了马尔科夫链,模型为DDPM。
重建与插值
对于
η
=
0
\eta=0
η=0时的DDIM,实质上就是将任意正态噪声向量变换为图片的一个确定性变换,此时
x
t
{x}_{t}
xt生成
x
t
−
1
{x}_{t-1}
xt−1的更新公式就变为:
x
t
−
1
=
α
t
−
1
(
x
t
−
1
−
α
t
ϵ
θ
(
x
t
,
t
)
α
t
)
+
1
−
α
t
−
1
⋅
ϵ
θ
(
x
t
,
t
)
\mathbf{x}_{t-1}=\sqrt{\alpha_{t-1}}\Big(\frac{\mathbf{x}_t-\sqrt{1-\alpha_t}\epsilon_\theta(\mathbf{x}_t,t)}{\sqrt{\alpha_t}}\Big)+\sqrt{1-\alpha_{t-1}}\cdot\epsilon_\theta(\mathbf{x}_t,t)
xt−1=αt−1(αtxt−1−αtϵθ(xt,t))+1−αt−1⋅ϵθ(xt,t)
对上式作等价变换可以得到:
x
t
−
1
α
t
−
1
=
x
t
α
t
+
(
1
−
α
t
−
1
α
t
−
1
−
1
−
α
t
α
t
)
ϵ
θ
(
x
t
,
t
)
\frac{\mathbf{x}_{t-1}}{\sqrt{\alpha_{t-1}}}=\frac{\mathbf{x}_{t}}{\sqrt{\alpha_{t}}}+\Big(\sqrt{\frac{1-\alpha_{t-1}}{\alpha_{t-1}}}-\sqrt{\frac{1-\alpha_{t}}{\alpha_{t}}}\Big)\epsilon_{\theta}(\mathbf{x}_{t},t)
αt−1xt−1=αtxt+(αt−11−αt−1−αt1−αt)ϵθ(xt,t)
当
T
T
T足够大,或者说
α
t
\alpha_{t}
αt与
α
t
−
1
\alpha_{t-1}
αt−1足够小时,我们可以将上式视为某个常微分方程ODE的差分形式。
x
t
−
Δ
t
α
t
−
Δ
t
=
x
t
α
t
+
(
1
−
α
t
−
Δ
t
α
t
−
Δ
t
−
1
−
α
t
α
t
)
ϵ
θ
(
x
t
,
t
)
\frac{\mathbf{x}_{t-\Delta t}}{\sqrt{\alpha_{t-\Delta t}}}=\frac{\mathbf{x}_t}{\sqrt{\alpha_t}}+\Big(\sqrt{\frac{1-\alpha_{t-\Delta t}}{\alpha_{t-\Delta t}}}-\sqrt{\frac{1-\alpha_t}{\alpha_t}}\Big)\epsilon_\theta(\mathbf{x}_t,t)
αt−Δtxt−Δt=αtxt+(αt−Δt1−αt−Δt−αt1−αt)ϵθ(xt,t)
这里令
σ
=
1
−
α
/
α
,
x
ˉ
=
x
/
α
\sigma={\sqrt{1-\alpha}}/{\sqrt{\alpha}},{\bar{\mathbf{x}}}=\mathbf{x}/{\sqrt{\alpha}}
σ=1−α/α,xˉ=x/α,它们都是关于
t
t
t的函数,这样对应的ODE就是:
d
x
ˉ
(
t
)
=
ϵ
θ
(
x
ˉ
(
t
)
σ
2
+
1
,
t
)
d
σ
(
t
)
\mathrm{d}\bar{\mathbf{x}}(t)=\epsilon_{\theta}(\frac{\bar{\mathbf{x}}(t)}{\sqrt{\sigma^{2}+1}},t)\mathrm{d}\sigma(t)
dxˉ(t)=ϵθ(σ2+1xˉ(t),t)dσ(t)
那么可以由一个原始图像
x
0
x_0
x0得到对应的随机噪音
x
T
x_T
xT,然后我们再用
x
T
x_T
xT进行生成就可以重建原始图像
x
0
x_0
x0,可以得到较低的重建误差
x
t
+
1
α
t
+
1
=
x
t
α
t
+
(
1
−
α
t
+
1
α
t
+
1
−
1
−
α
t
α
t
)
ϵ
θ
(
x
t
,
t
)
\frac{\mathbf{x}_{t+1}}{\sqrt{\alpha_{t+1}}}=\frac{\mathbf{x}_{t}}{\sqrt{\alpha_{t}}}+\Big(\sqrt{\frac{1-\alpha_{t+1}}{\alpha_{t+1}}}-\sqrt{\frac{1-\alpha_{t}}{\alpha_{t}}}\Big)\epsilon_{\theta}(\mathbf{x}_{t},t)
αt+1xt+1=αtxt+(αt+11−αt+1−αt1−αt)ϵθ(xt,t)
也就是说,将生成过程等同于求解常微分方程后,可以借助常微分方程的数值解法,为生成过程的加速提供更丰富多样的手段。
η
=
0
\eta=0
η=0时,DDIM,所以跟GAN类似,我们可以对这两个随机噪音进行插值生成新的
x
T
x_T
xT,那么将生成融合的图像。这里采用的插值方法是球面线性插值( spherical linear interpolation),参数
α
\alpha
α控制插值系数:
x
T
(
α
)
=
sin
(
(
1
−
α
)
θ
)
sin
(
θ
)
x
T
(
0
)
+
sin
(
α
θ
)
sin
(
θ
)
x
T
(
1
)
θ
=
arccos
(
(
x
T
(
0
)
)
T
x
T
(
1
)
∥
x
T
(
0
)
∥
∥
x
T
(
1
)
∥
)
\mathbf{x}_T^{(\alpha)}=\frac{\sin((1-\alpha)\theta)}{\sin(\theta)}\mathbf{x}_T^{(0)}+\frac{\sin(\alpha\theta)}{\sin(\theta)}\mathbf{x}_T^{(1)}\quad\theta=\arccos\Big(\frac{(\mathbf{x}_T^{(0)})^\mathrm{T}\mathbf{x}_T^{(1)}}{\|\mathbf{x}_T^{(0)}\|\|\mathbf{x}_T^{(1)}\|}\Big)
xT(α)=sin(θ)sin((1−α)θ)xT(0)+sin(θ)sin(αθ)xT(1)θ=arccos(∥xT(0)∥∥xT(1)∥(xT(0))TxT(1))