快速理解扩散概率模型
注意:本文从贝叶斯公式出发理解去噪过程的原理,本文公式推导并不完全,跳过了一些繁琐运算的过程,但足够理解扩散模型的两个过程在做些什么,深入理解数学原理可以看看https://kexue.fm/archives/9119系列博客。
本文内容主要来自:
若有错误,望看官及时指正!
Overview
Process
正向加噪
从一张真实干净的 x 0 x_0 x0逐步加噪到 x T x_T xT,每一步所加入的噪声比重越来越大,直至图像成为完全的高斯噪声。
整个过程是满足马尔可夫链性质(memoryless
),
x
t
x_t
xt只与
x
t
−
1
x_{t-1}
xt−1有关(
t
∈
[
0
,
T
−
1
]
t \in [0, T-1]
t∈[0,T−1],
T
T
T为设定的总扩散步数):
x
t
=
α
t
x
t
−
1
+
1
−
α
t
ϵ
t
x_t = \sqrt{\alpha_t}x_{t-1} + \sqrt{1-\alpha_t}\epsilon_t
xt=αtxt−1+1−αtϵt
记
β
t
=
1
−
α
t
\beta_t=1-\alpha_t
βt=1−αt,这些都是根据
T
T
T预设好的常数,
x
0
→
x
T
x_0\rightarrow x_T
x0→xT加噪过程中,噪声系数开始较小,后来越来越大。
根据上式,可以进行递推:
x
t
=
α
t
x
t
−
1
+
1
−
α
t
ϵ
t
x_t=\sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}\epsilon_t
xt=αtxt−1+1−αtϵt
=
α
t
x
t
−
1
+
β
t
ϵ
t
=\sqrt{\alpha_t}x_{t-1}+\sqrt{\beta_t}\epsilon_t
=αtxt−1+βtϵt
=
α
t
(
α
t
−
1
x
t
−
2
+
β
t
−
1
ϵ
t
−
1
)
+
β
t
ϵ
t
=\sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{\beta_{t-1}}\epsilon_{t-1})+\sqrt{\beta_{t}}\epsilon_t
=αt(αt−1xt−2+βt−1ϵt−1)+βtϵt
=
∏
k
=
1
t
α
k
x
0
+
β
1
ϵ
1
∏
k
=
2
t
α
k
+
β
2
ϵ
2
∏
k
=
3
t
α
k
+
.
.
.
+
β
t
−
1
ϵ
t
−
1
∏
k
=
t
t
α
k
+
β
t
ϵ
t
=\sqrt{\prod_{k=1}^{t}{\alpha_k}}x_0+\sqrt{\beta_1}\epsilon_1\prod_{k=2}^{t}\sqrt{\alpha_{k}}+\sqrt{\beta_2}\epsilon_2\prod_{k=3}^{t}\sqrt{\alpha_{k}}+...+\sqrt{\beta_{t-1}}\epsilon_{t-1}\prod_{k=t}^{t}\sqrt{\alpha_{k}}+\sqrt{\beta_{t}}\epsilon_t
=k=1∏tαkx0+β1ϵ1k=2∏tαk+β2ϵ2k=3∏tαk+...+βt−1ϵt−1k=t∏tαk+βtϵt
其中,每次加入的噪声
ϵ
t
\epsilon_t
ϵt都是服从标准正态分布
N
(
0
,
1
)
\mathcal N(0,1)
N(0,1),所以上式中带有噪声的每一项都可写成均值为0
,标准差不同的正态分布:
N
(
0
,
β
?
−
1
∏
k
=
?
t
α
k
I
)
\mathcal{N}(0, \sqrt{\beta_{?-1}}\prod_{k=?}^{t}\sqrt{\alpha_k}I)
N(0,β?−1k=?∏tαkI)
累乘
∏
k
t
x
\prod_{k}^t x
∏ktx用
x
ˉ
k
\bar x_k
xˉk代替表示;
多个正态分布相加或减,方差体现为相加:
x
t
=
α
ˉ
t
x
0
+
N
(
0
,
α
ˉ
2
β
1
)
+
N
(
0
,
α
ˉ
3
β
2
)
+
N
(
0
,
α
ˉ
4
β
3
)
+
.
.
.
+
N
(
0
,
α
t
β
t
−
1
)
+
N
(
0
,
β
t
)
x_t=\sqrt{\bar{\alpha}_t}x_0+\mathcal{N}(0, \sqrt{\bar{\alpha}_{2}\beta_1})+\mathcal{N}(0, \sqrt{\bar{\alpha}_{3}\beta_2})+\mathcal{N}(0, \sqrt{\bar{\alpha}_{4}\beta_3})+...+\mathcal{N}(0, \sqrt{{\alpha}_{t}\beta_{t-1}})+\mathcal{N}(0, \sqrt{\beta_t})
xt=αˉtx0+N(0,αˉ2β1)+N(0,αˉ3β2)+N(0,αˉ4β3)+...+N(0,αtβt−1)+N(0,βt)
带入
β
t
=
1
−
α
t
\beta_t=1-\alpha_t
βt=1−αt,得到:
x
t
=
α
ˉ
t
x
0
+
1
−
α
ˉ
t
ϵ
t
x_t = \sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon_t
xt=αˉtx0+1−αˉtϵt
=
N
(
α
ˉ
t
x
0
,
1
−
α
ˉ
t
)
=\mathcal{N}(\sqrt{\bar{\alpha}_t}x_0,1-\bar{\alpha}_t)
=N(αˉtx0,1−αˉt)
所以通过正向过程,我们得到了任意时刻图像
x
t
x_t
xt关于原始图像
x
0
x_0
x0的表达式,其中
α
ˉ
t
\sqrt{\bar{\alpha}_t}
αˉt和
α
ˉ
t
\bar{\alpha}_t
αˉt均为常数。
逆向去噪
我们的目标是从
x
T
x_T
xT得到一张尽可能真实的图像
x
0
x_0
x0,但也不能一蹴而就,需要一步一步从
x
t
→
x
t
−
1
x_t\rightarrow x_{t-1}
xt→xt−1,由贝叶斯公式可得:
p
(
x
t
−
1
∣
x
t
)
=
p
(
x
t
∣
x
t
−
1
)
p
(
x
t
−
1
)
p
(
x
t
)
p(x_{t-1}|x_t) = p(x_t|x_{t-1})\frac{p(x_{t-1})}{p(x_t)}
p(xt−1∣xt)=p(xt∣xt−1)p(xt)p(xt−1)
发现如果不给定
x
0
x_0
x0,
x
t
,
x
t
−
1
x_t, x_{t-1}
xt,xt−1这些都没有意义,无法计算,由于马尔科夫链的性质,我们可以加上条件
x
0
x_0
x0:
p
(
x
t
−
1
∣
x
t
,
x
0
)
=
p
(
x
t
∣
x
t
−
1
,
x
0
)
p
(
x
t
−
1
∣
x
0
)
p
(
x
t
∣
x
0
)
p(x_{t-1}|x_t,x_0) = p(x_t|x_{t-1},x_0)\frac{p(x_{t-1}|x_0)}{p(x_t|x_0)}
p(xt−1∣xt,x0)=p(xt∣xt−1,x0)p(xt∣x0)p(xt−1∣x0)
由正向加噪过程得到的公式,我们可将等式右边的三项均表示成正态分布的形式:
p
(
x
t
−
1
∣
x
t
)
=
N
(
α
t
x
t
−
1
,
1
−
α
t
)
p(x_{t-1}|x_t)=\mathcal{N}(\sqrt{\alpha_t}x_{t-1},1-\alpha_t)
p(xt−1∣xt)=N(αtxt−1,1−αt)
p
(
x
t
−
1
∣
x
0
)
=
N
(
x
t
−
1
;
α
ˉ
t
−
1
x
0
,
1
−
α
ˉ
t
−
1
)
p(x_{t-1}|x_0)=\mathcal{N}(x_{t-1};\sqrt{\bar{\alpha}_{t-1}}x_0,1-\bar{\alpha}_{t-1})
p(xt−1∣x0)=N(xt−1;αˉt−1x0,1−αˉt−1)
p
(
x
t
∣
x
0
)
=
N
(
x
t
;
α
ˉ
t
x
0
,
1
−
α
ˉ
t
)
p(x_{t}|x_0)=\mathcal{N}(x_{t};\sqrt{\bar{\alpha}_{t}}x_0,1-\bar{\alpha}_{t})
p(xt∣x0)=N(xt;αˉtx0,1−αˉt)
这样一来呢,根据正态分布的计算式:
N
(
μ
,
σ
2
)
=
1
2
π
σ
e
−
(
x
−
μ
)
2
2
σ
2
\mathcal{N}(\mu,\sigma^2) = \frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(x-\mu)^2}{2\sigma^2}}
N(μ,σ2)=2πσ1e−2σ2(x−μ)2
我们把等式右边的正态分布都替换成上面的形式,经过一段复杂的计算后,将贝叶斯等式右边的三项进行化简并通过配方表示成正态分布的形式:
p
(
x
t
−
1
∣
x
t
,
x
0
)
∝
e
−
1
2
[
(
α
t
β
t
+
1
1
−
α
ˉ
t
−
1
)
x
t
−
1
2
−
(
2
α
t
β
t
x
t
+
2
α
ˉ
t
−
1
1
−
α
ˉ
t
−
1
x
0
)
x
t
−
1
+
C
(
x
t
,
x
0
)
]
p(x_{t-1}|x_t,x_0) \propto e^{-\frac{1}{2}[(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}})x_{t-1}^2-(\frac{2\sqrt{\alpha_t}}{\beta_t}x_t+\frac{2\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}x_0)x_{t-1}+C(x_t, x_0)]}
p(xt−1∣xt,x0)∝e−21[(βtαt+1−αˉt−11)xt−12−(βt2αtxt+1−αˉt−12αˉt−1x0)xt−1+C(xt,x0)]
从而上述分布的均值可表示成:
μ
~
t
(
x
t
,
x
0
)
=
α
t
(
1
−
α
ˉ
t
−
1
)
1
−
α
ˉ
t
x
t
+
α
ˉ
t
−
1
β
t
1
−
α
ˉ
t
x
0
\tilde\mu_t(x_t,x_0)=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}x_t + \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x_0
μ~t(xt,x0)=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1βtx0
可以看到,正态分布
p
(
x
t
−
1
∣
x
t
,
x
0
)
p(x_{t-1}|x_t,x_0)
p(xt−1∣xt,x0)的均值
μ
~
t
\tilde\mu_t
μ~t只与
x
t
,
x
0
x_t, x_0
xt,x0有关,但我们所求的就是
x
0
x_0
x0,所以显然
x
0
x_0
x0无法使用,这里用到了预估-修正
的思想,通过正向过程得到的
x
t
x_t
xt关于
x
0
x_0
x0的表达式得到
x
0
x_0
x0的表达式:
x
0
=
x
t
−
1
−
α
ˉ
t
ϵ
t
α
ˉ
t
x_0=\frac{x_t-\sqrt{1-\bar{\alpha}_t\epsilon_t}}{\sqrt{\bar{\alpha}_t}}
x0=αˉtxt−1−αˉtϵt
带入可得到均值:
μ
~
t
=
1
α
t
(
x
t
−
β
t
1
−
α
ˉ
t
ϵ
t
)
\tilde\mu_t=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_t)
μ~t=αt1(xt−1−αˉtβtϵt)
所以此时唯一未知量便是此时刻的噪声
ϵ
t
\epsilon_t
ϵt,扩散模型做的就是这个事情,通过一个UNet
结构模型来得到每个时刻的噪声预测值
ϵ
~
t
\tilde\epsilon_t
ϵ~t。
方差在很多我看的几篇工作中都是设为固定值(具体原理没有深入看),所以只要能够预测出此时噪声 ϵ ~ t \tilde\epsilon_t ϵ~t,我们便得到了所求的分布 p ( x t − 1 ∣ x t , x 0 ) p(x_{t-1}|x_t,x_0) p(xt−1∣xt,x0)的均值,也便能够实现逐步 x t → x t − 1 x_t\rightarrow x_{t-1} xt→xt−1逆推过程。
训练&推理
训练过程是包括正向加噪和反向去噪的,在正向过程中加入的噪声 ϵ t \epsilon_t ϵt将作为真实值,与反向过程的预测值 ϵ ^ t \hat\epsilon_t ϵ^t进行损失计算;
推理过程就不涉及正向过程了,直接从噪声开始去噪。