写在前面的一些话
因为自己项目需要,以及总是听说扩散模型,所以自己去b站看了视频,博客主要建立于迪哥视频,尽量写的通俗易懂,致力于高效省时的帮助大家搞明白扩散模型的原理,让小白也能读懂这篇论文
注意!!文章可能涉及比较多公式,但不要害怕!!结合我的说明,看懂没问题的,一步步来!不要着急,不要跳步!如果有错误,欢迎指正!有什么问题欢迎在评论区讨论!
简述
最近经常听说扩散模型,甚至可以打败GAN。回顾GAN,我们需要同时训练生成器和判别器,可能会难以收敛以及学习到一些我们不想要的特征。而diffusion model做的事情就是用了一种更简单的方式来解释生成模型应该怎么学习和生成。diffusion model火起来是因为DALLE 2的出现(也是openai的,跟chatgpt出自一个公司),实现文字转图片,能得到非常惊艳的效果(如下图,生成一个牛油果形状的沙发),可以自行搜索一下他们的网站
整个diffusion model可以分为两部分,一个是前向扩散过程,另一个是逆扩散过程,通俗理解为:前向扩散过程不停的往图片上加服从高斯分布的噪声,加到使图片变得“面目全非”(下图从右到左),逆扩散过程就是不停的减噪声然后复原成图片(从左到右)
在原论文中,扩散过程需要进行2000次加噪声的步骤,实际操作中大约200-500次。在扩散过程中,每次往图片上加的噪声就是逆过程的标签,接下来我会分别解释前向扩散过程以及逆扩散过程
前向扩散过程 forward diffusion
前面说到,扩散过程简单来说就是不停的往图片里加噪声,把图片加的面目全非。那怎么加,加多少呢?论文中给出核心公式:
x
t
=
α
t
x
t
−
1
+
1
−
α
t
z
1
(
1
)
x_t=\sqrt{\alpha_t} x_{t-1}+\sqrt{1-\alpha_t} z_1 \quad(1)
xt=αtxt−1+1−αtz1(1)
这个公式怎么来的呢?别急,我们一步步来看这个公式
首先,值得一提的是,整个扩散模型是符合马尔可夫定理的,也就是说t时刻的分布只与t-1时刻有关,所以为什么公式里只出现了 x t − 1 x_{t-1} xt−1而没有 x t − 2 x_{t-2} xt−2, x t − 3 x_{t-3} xt−3, x t − 4 x_{t-4} xt−4…
其次, α t \alpha_t αt是一个经验常量,且 α t \alpha_t αt会随着t的增大而减小,这是实验前决定的; z 1 z_1 z1 (包括文章后面出现的 z 2 z_2 z2, z 3 z_3 z3…)都是服从标准高斯分布的噪声~N(0,I) 。由此,我们可以将这个公式理解为一部分的 x t − 1 x_{t-1} xt−1加上了一部分的 z 1 z_{1} z1,也就是说, x t x_{t} xt等于前一时刻的分布 x t − 1 x_{t-1} xt−1和标准高斯分布 z 1 z_{1} z1的权重和,而这个权重由 α t \alpha_t αt决定。因为随着t的增大 α t \alpha_t αt会减小,所以 x t − 1 x_{t-1} xt−1的权重会越来越小, z 1 z_{1} z1的权重会越来越大。因此随着t的增大,噪声占比越来越大,前一时刻的分布占比越来越小。
好了,到这里我们搞懂了其中一个核心公式。但有一个问题,假如我加噪声加了1000次,我要是想知道第一千次的分布,难道要从第一步开始一步步往后推吗,知道了
x
0
x_0
x0我才能知道
x
1
x_1
x1,知道
x
1
x_1
x1我才能知道
x
2
x_2
x2?这也太慢了吧。因此论文又给了我们另一个公式:
x
t
=
α
ˉ
t
x
0
+
1
−
α
ˉ
t
z
t
,
α
ˉ
t
=
α
t
α
t
−
1
α
t
−
2
.
.
.
(
2
)
x_t=\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} {z}_t,\bar{\alpha}_t=\alpha_t \alpha_{t-1} \alpha_{t-2}... \quad(2)
xt=αˉtx0+1−αˉtzt,αˉt=αtαt−1αt−2...(2)
这又是怎么来的呢?接下来慢慢解释。
让我们先根据公式(1)写出
x
t
−
1
x_{t-1}
xt−1的公式 (把(1)中的t换成t-1就行了);
x
t
−
1
=
α
t
−
1
x
t
−
2
+
1
−
α
t
−
1
z
2
(
3
)
x_{t-1}=\sqrt{\alpha_{t-1}} x_{t-2}+\sqrt{1-\alpha_{t-1}} z_2 \quad(3)
xt−1=αt−1xt−2+1−αt−1z2(3)
再直接把(3)中的
x
t
−
1
x_{t-1}
xt−1带入到(1):
x
t
=
α
t
(
α
t
−
1
x
t
−
2
+
1
−
α
t
−
1
z
2
)
+
1
−
α
t
z
1
(
4
)
x_t=\sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}} x_{t-2}+\sqrt{1-\alpha_{t-1}} z_2\right)+\sqrt{1-\alpha_t} z_1 \quad(4)
xt=αt(αt−1xt−2+1−αt−1z2)+1−αtz1(4)
把
α
t
\sqrt{\alpha_t}
αt乘进去,括号移一下:
x
t
=
α
t
α
t
−
1
x
t
−
2
+
(
α
t
(
1
−
α
t
−
1
)
z
2
+
1
−
α
t
z
1
)
(
5
)
x_t=\sqrt{\alpha_t\alpha_{t-1}} x_{t-2}+\left(\sqrt{\alpha_t(1-\alpha_{t-1})} z_2+\sqrt{1-\alpha_t} z_1\right) \quad(5)
xt=αtαt−1xt−2+(αt(1−αt−1)z2+1−αtz1)(5)
到这里应该没什么难点,只是简单的代入。我们仔细观察一下公式(5),发现括号内是两个高斯分布相加(记住:
z
1
,
z
2
,
z
3
.
.
.
z_1,z_2,z_3...
z1,z2,z3...都是服从标准的高斯分布),我们知道两个高斯分布相加还是高斯分布,具体推导可以参考另一个博主的博客。那么括号里加出来新的高斯分布具体是什么呢?
如果我们把公式(5)中的
α
t
(
1
−
α
t
−
1
)
z
2
\sqrt{\alpha_t(1-\alpha_{t-1})} z_2
αt(1−αt−1)z2看作一整个分布,那:
α
t
(
1
−
α
t
−
1
)
z
2
∼
N
(
0
,
α
t
(
1
−
α
t
−
1
)
)
(
6
)
\sqrt{\alpha_t(1-\alpha_{t-1})} z_2 \sim \mathcal{N}(0, {\alpha_t(1-\alpha_{t-1})})\quad(6)
αt(1−αt−1)z2∼N(0,αt(1−αt−1))(6)
同理对于公式(5)中的
1
−
α
t
z
1
\sqrt{1-\alpha_t} z_1
1−αtz1:
1
−
α
t
z
1
∼
N
(
0
,
1
−
α
t
)
(
7
)
\sqrt{1-\alpha_t} z_1 \sim \mathcal{N}(0, {1-\alpha_t})\quad(7)
1−αtz1∼N(0,1−αt)(7)
插一句,这是因为,在一个高斯分布前面乘上一个系数相当于改变它的标准差,给一个高斯分布加上或减去某个数相当于改变它的均值
ok,那根据博客里推导的公式,两个高斯分布相加后新的高斯分布应为:
N
(
0
,
σ
1
2
I
)
+
N
(
0
,
σ
2
2
I
)
∼
N
(
0
,
(
σ
1
2
+
σ
2
2
)
I
)
(
8
)
\mathcal{N}\left(0, \sigma_1^2 \mathbf{I}\right)+\mathcal{N}\left(0, \sigma_2^2 \mathbf{I}\right) \sim \mathcal{N}\left(0,\left(\sigma_1^2+\sigma_2^2\right) \mathbf{I}\right) \quad(8)
N(0,σ12I)+N(0,σ22I)∼N(0,(σ12+σ22)I)(8)
我们把(6)(7)代入(8)一下,可得到新的分布具体为 (就是把(6)(7)的标准差加起来):
∼
N
(
0
,
[
(
α
t
(
1
−
α
t
−
1
)
)
+
(
1
−
α
t
)
]
I
)
=
N
(
0
,
(
α
t
−
α
t
α
t
−
1
+
1
−
α
t
)
I
)
=
N
(
0
,
(
1
−
α
t
α
t
−
1
)
I
)
(
9
)
\sim \mathcal{N}\left(0,\left[(\alpha_t(1-\alpha_{t-1}))+(1-\alpha_t)\right] \mathbf{I}\right) = \mathcal{N}\left(0,\left(\alpha_t-\alpha_t\alpha_{t-1}+1-\alpha_t\right) \mathbf{I}\right) =\mathcal{N}\left(0,\left(1-\alpha_t\alpha_{t-1}\right) \mathbf{I} \right) \quad(9)
∼N(0,[(αt(1−αt−1))+(1−αt)]I)=N(0,(αt−αtαt−1+1−αt)I)=N(0,(1−αtαt−1)I)(9)
为了保持一致,我们就写成:
N
(
0
,
(
1
−
α
t
α
t
−
1
)
I
)
=
1
−
α
t
α
t
−
1
z
ˉ
2
,
z
ˉ
2
∼
N
(
0
,
I
)
\mathcal{N}\left(0,\left(1-\alpha_t\alpha_{t-1}\right) \mathbf{I} \right) =\sqrt{1-\alpha_t \alpha_{t-1}} \bar{z}_2, \quad \bar{z}_2 \sim \mathcal{N}\left(0, \mathbf{I}\right)
N(0,(1−αtαt−1)I)=1−αtαt−1zˉ2,zˉ2∼N(0,I)
那公式(5)就可以改写成 (只用把括号里的改成我们新推导出来的高斯分布就可以了):
x
t
=
α
t
α
t
−
1
x
t
−
2
+
1
−
α
t
α
t
−
1
z
ˉ
2
(
10
)
x_t=\sqrt{\alpha_t\alpha_{t-1}} x_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}} \bar{z}_2 \quad(10)
xt=αtαt−1xt−2+1−αtαt−1zˉ2(10)
如果没懂就多看几遍公式,没有难点,只是代入可能会比较绕。如果懂了就接着往下面看
我们来比较一下公式(1)和(10),我再把他们写出来:
x t = α t x t − 1 + 1 − α t z 1 ( 1 ) x_t=\sqrt{\alpha_t} x_{t-1}+\sqrt{1-\alpha_t} z_1 \quad(1) xt=αtxt−1+1−αtz1(1)
x
t
=
α
t
α
t
−
1
x
t
−
2
+
1
−
α
t
α
t
−
1
z
ˉ
2
(
10
)
x_t=\sqrt{\alpha_t\alpha_{t-1}} x_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}} \bar{z}_2 \quad(10)
xt=αtαt−1xt−2+1−αtαt−1zˉ2(10)
你发现了什么?是不是发现这个公式似乎可以类推?每往前推一项,只需要在权重的根号下多乘一项就可以了?那你应该能猜到如何用
x
0
x_0
x0来表示
x
t
x_t
xt了,就是公式(2):
x
t
=
α
ˉ
t
x
0
+
1
−
α
ˉ
t
z
t
,
α
ˉ
t
=
α
t
α
t
−
1
α
t
−
2
.
.
.
(
2
)
x_t=\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} {z}_t,\bar{\alpha}_t=\alpha_t \alpha_{t-1} \alpha_{t-2}... \quad(2)
xt=αˉtx0+1−αˉtzt,αˉt=αtαt−1αt−2...(2)
OK,到这里我们就推导出了扩散过程的一个重要的核心公式,有了公式(2),我只需要知道起始的 x 0 x_0 x0以及现在到了第几步,也就是t,我就可以直接算出 x t x_t xt是多少。扩散过程到这里就结束了,如果没懂的就回顾几遍,接下来要开始说逆扩散过程咯。
逆扩散过程 reverse diffusion
介绍完扩散过程,我们现在来说复原过程。再把这个图拿出来看。我们现在知道怎么算
q
(
x
t
∣
x
t
−
1
)
q(x_t|x_{t-1})
q(xt∣xt−1),也就是扩散过程,但不知道怎么算
p
θ
(
x
t
−
1
∣
x
t
)
p_\theta(x_{t-1}|x_t)
pθ(xt−1∣xt)。括号里的两个变量换了位置,这使我们很容易就联想到贝叶斯公式。最常用的贝叶斯公式如下:
q
(
x
t
−
1
∣
x
t
)
=
q
(
x
t
∣
x
t
−
1
)
q
(
x
t
−
1
)
q
(
x
t
)
(
11
)
q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)=q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right) \frac{q\left(\mathbf{x}_{t-1} \right)}{q\left(\mathbf{x}_t \right)} \quad(11)
q(xt−1∣xt)=q(xt∣xt−1)q(xt)q(xt−1)(11)
等式两边同时引入
x
0
x_0
x0这个变量:
q
(
x
t
−
1
∣
x
t
,
x
0
)
=
q
(
x
t
∣
x
t
−
1
,
x
0
)
q
(
x
t
−
1
∣
x
0
)
q
(
x
t
∣
x
0
)
(
12
)
q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)=q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)} \quad(12)
q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt∣x0)q(xt−1∣x0)(12)
公式(12)详细证明可以看这里,我自己推的,不想看可以略过,不会影响后面理解:
证明用到了条件概率公式:
q
(
A
∣
B
)
=
q
(
A
,
B
)
q
(
B
)
q(A|B) = \frac{q(A,B)}{q(B)}
q(A∣B)=q(B)q(A,B)
由此可证:
q
(
x
t
∣
x
t
−
1
,
x
0
)
q
(
x
t
−
1
∣
x
0
)
q
(
x
t
∣
x
0
)
=
q
(
x
t
,
x
t
−
1
,
x
0
)
q
(
x
t
−
1
,
x
0
)
q
(
x
t
−
1
,
x
0
)
q
(
x
0
)
q
(
x
t
,
x
0
)
q
(
x
0
)
=
q
(
x
t
,
x
t
−
1
,
x
0
)
q
(
x
t
,
x
0
)
=
q
(
x
t
−
1
∣
x
t
,
x
0
)
q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)} = \frac{q(\mathbf{x}_t,\mathbf{x}_{t-1},\mathbf{x}_0)}{q(\mathbf{x}_{t-1},\mathbf{x}_0)} \frac{ \frac{q(\mathbf{x}_{t-1},\mathbf{x}_0)}{q(\mathbf{x}_0)}}{ \frac{q(\mathbf{x}_{t},\mathbf{x}_0)}{q(\mathbf{x}_0)}} = \frac{q(\mathbf{x}_t,\mathbf{x}_{t-1},\mathbf{x}_0)}{q(\mathbf{x}_{t},\mathbf{x}_0)} = q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)
q(xt∣xt−1,x0)q(xt∣x0)q(xt−1∣x0)=q(xt−1,x0)q(xt,xt−1,x0)q(x0)q(xt,x0)q(x0)q(xt−1,x0)=q(xt,x0)q(xt,xt−1,x0)=q(xt−1∣xt,x0)
公式(12)右边总共由三项组成:
- q ( x t ∣ x t − 1 , x 0 ) q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) q(xt∣xt−1,x0)
- q ( x t ∣ x 0 ) q\left(\mathbf{x}_t \mid \mathbf{x}_0\right) q(xt∣x0)
- q ( x t − 1 ∣ x 0 ) q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right) q(xt−1∣x0)
我们一项一项来看:
- q ( x t ∣ x t − 1 , x 0 ) q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) q(xt∣xt−1,x0):这个式子表示已知 x t − 1 x_{t-1} xt−1和 x 0 x_0 x0去计算 x t x_t xt,是不是觉得很眼熟?对,就是我们在一开始介绍前向扩散过程的时候给出的公式,知道前一项算后一项。也就是公式(1)所表示的: x t = α t x t − 1 + 1 − α t z x_t=\sqrt{\alpha_t} x_{t-1}+\sqrt{1-\alpha_t} z xt=αtxt−1+1−αtz。在这里似乎已知 x 0 x_0 x0看起来没有存在的必要,但这是为了接下来计算后面的两项。
- q ( x t ∣ x 0 ) q\left(\mathbf{x}_{t} \mid \mathbf{x}_0\right) q(xt∣x0):理解一下这个式子表示的意思,就是已知 x 0 x_0 x0去求 x t x_{t} xt,听起来好像也很熟悉。这不就是我们之前在前向扩散过程中推导出来的公式吗?若已知 x 0 x_0 x0和时间t,我们可以根据公式(2)直接求得 x t x_t xt: x t = α ˉ t x 0 + 1 − α ˉ t z , α ˉ t = α t α t − 1 α t − 2 . . . x_{t}=\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} {z},\bar{\alpha}_t=\alpha_t \alpha_{t-1} \alpha_{t-2}... xt=αˉtx0+1−αˉtz,αˉt=αtαt−1αt−2...
- q ( x t − 1 ∣ x 0 ) q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right) q(xt−1∣x0):同理,已知 x 0 x_0 x0去求 x t − 1 x_{t-1} xt−1,只需要把公式(2)中的t替换成t-1就行了: x t − 1 = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 z x_{t-1}=\sqrt{\bar{\alpha}_{t-1}} x_0+\sqrt{1-\bar{\alpha}_{t-1}} {z} xt−1=αˉt−1x0+1−αˉt−1z
总结,公式(12)右边三项:
- q ( x t ∣ x t − 1 , x 0 ) = a t x t − 1 + 1 − α t z ∼ N ( a t x t − 1 , 1 − α t ) ( 13 ) q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) =\sqrt{a_t} x_{t-1}+\sqrt{1-\alpha_t} z \quad \sim \mathcal{N}\left(\sqrt{a_t} x_{t-1}, 1-\alpha_t\right) \quad(13) q(xt∣xt−1,x0)=atxt−1+1−αtz∼N(atxt−1,1−αt)(13)
- q ( x t ∣ x 0 ) = α ˉ t x 0 + 1 − α ˉ t z ∼ N ( α ˉ t x 0 , 1 − α ˉ t ) ( 14 ) q\left(\mathbf{x}_t \mid \mathbf{x}_0\right) = \sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} z \quad \sim \mathcal{N}\left(\sqrt{{\bar{\alpha}}_t} x_0, 1-\bar{\alpha}_t\right) \quad(14) q(xt∣x0)=αˉtx0+1−αˉtz∼N(αˉtx0,1−αˉt)(14)
- q ( x t − 1 ∣ x 0 ) = α ˉ t − 1 x 0 + 1 − a ˉ t − 1 z ∼ N ( α ˉ t − 1 x 0 , 1 − a ˉ t − 1 ) ( 15 ) q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right) = \sqrt{\bar{\alpha}_{t-1}} x_0+\sqrt{1-\bar{a}_{t-1}} z \quad \sim \mathcal{N}\left( \sqrt{\bar{\alpha}_{t-1}} x_0, 1-\bar{a}_{t-1}\right) \quad(15) q(xt−1∣x0)=αˉt−1x0+1−aˉt−1z∼N(αˉt−1x0,1−aˉt−1)(15)
Again,在一个高斯分布前面乘上一个系数相当于改变它的标准差,给一个高斯分布加上或减去某个数相当于改变它的均值
然后我们发现了什么?公式(12)的右边三项全都是高斯分布!!! 这意味着我们又可以进行一波操作来化简等式的右边
我们再来看一眼公式(12):
q
(
x
t
−
1
∣
x
t
,
x
0
)
=
q
(
x
t
∣
x
t
−
1
,
x
0
)
q
(
x
t
−
1
∣
x
0
)
q
(
x
t
∣
x
0
)
(
12
)
q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)=q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0\right) \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)} \quad(12)
q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt∣x0)q(xt−1∣x0)(12)
发现等式右半边就是一个高斯分布乘上另一个高斯分布再除以一个高斯分布。我们知道高斯分布长这个样子:
f
(
x
)
=
1
2
π
σ
exp
(
−
(
x
−
μ
)
2
2
σ
2
)
f(x)=\frac{1}{\sqrt{2 \pi} \sigma} \exp \left(-\frac{(x-\mu)^2}{2 \sigma^2}\right)
f(x)=2πσ1exp(−2σ2(x−μ)2)
所以可想而知,对于指数部分,两项相乘等于指数相加,相除等于指数相减。
因此等式(12)的右半边(这一步对照着公式(12)和(13)~(15)看就明白了,除的前面就是负号,乘的前面就是加号,然后把均值和标准差代入):
∝
exp
(
−
1
2
(
(
x
t
−
α
t
x
t
−
1
)
2
1
−
α
t
+
(
x
t
−
1
−
α
ˉ
t
−
1
x
0
)
2
1
−
α
ˉ
t
−
1
−
(
x
t
−
α
ˉ
t
x
0
)
2
1
−
α
ˉ
t
)
)
\propto \exp \left(-\frac{1}{2}\left(\frac{\left(\mathbf{x}_t-\sqrt{\alpha_t} \mathbf{x}_{t-1}\right)^2}{1-\alpha_t}+\frac{\left(\mathbf{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_t}\right)\right)
∝exp(−21(1−αt(xt−αtxt−1)2+1−αˉt−1(xt−1−αˉt−1x0)2−1−αˉt(xt−αˉtx0)2))
接下来繁琐的展开和合并同类项,可以选择跳过,只需要知道我们得到了公式(16) (为了简写:
β
t
=
1
−
α
t
\beta_t = 1 - \alpha_t
βt=1−αt)
=
exp
(
−
1
2
(
x
t
2
−
2
α
t
x
t
x
t
−
1
+
α
t
x
t
−
1
2
β
t
+
x
t
−
1
2
−
2
α
ˉ
t
−
1
x
0
x
t
−
1
+
α
ˉ
t
−
1
x
0
2
1
−
α
ˉ
t
−
1
−
(
x
t
−
α
ˉ
t
x
0
)
2
1
−
α
ˉ
t
)
)
=\exp \left(-\frac{1}{2}\left(\frac{\mathbf{x}_t^2-2 \sqrt{\alpha_t} \mathbf{x}_t \mathbf{x}_{t-1}+\alpha_t \mathbf{x}_{t-1}^2}{\beta_t}+\frac{\mathbf{x}_{t-1}^2-2 \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0 \mathbf{x}_{t-1}+\bar{\alpha}_{t-1} \mathbf{x}_0^2}{1-\bar{\alpha}_{t-1}}-\frac{\left(\mathbf{x}_t-\sqrt{\bar{\alpha}_t} \mathbf{x}_0\right)^2}{1-\bar{\alpha}_t}\right)\right)
=exp(−21(βtxt2−2αtxtxt−1+αtxt−12+1−αˉt−1xt−12−2αˉt−1x0xt−1+αˉt−1x02−1−αˉt(xt−αˉtx0)2))
=
exp
(
−
1
2
(
(
α
t
β
t
+
1
1
−
α
ˉ
t
−
1
)
x
t
−
1
2
−
(
2
α
t
β
t
x
t
+
2
α
ˉ
t
−
1
1
−
α
ˉ
t
−
1
x
0
)
x
t
−
1
+
C
(
x
t
,
x
0
)
)
)
(
16
)
=\exp \left(-\frac{1}{2}\left(\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) \mathbf{x}_{t-1}^2-\left(\frac{2 \sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{2 \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0\right) \mathbf{x}_{t-1}+C\left(\mathbf{x}_t, \mathbf{x}_0\right)\right)\right) \quad(16)
=exp(−21((βtαt+1−αˉt−11)xt−12−(βt2αtxt+1−αˉt−12αˉt−1x0)xt−1+C(xt,x0)))(16)
到这其实可以看出来是个完全平方式,也就是:
=
exp
(
−
1
2
(
A
x
2
−
B
x
+
C
)
)
(
17
)
=\exp(-\frac{1}{2}(Ax^2-Bx+C)) \quad(17)
=exp(−21(Ax2−Bx+C))(17)
又因为我们想把他表示成一个高斯分布且:
exp
(
−
(
x
−
μ
)
2
2
σ
2
)
=
exp
(
−
1
2
(
1
σ
2
x
2
−
2
μ
σ
2
x
+
μ
2
σ
2
)
)
(
18
)
\exp(-\frac{(x-\mu)^2}{2\sigma^2}) = \exp \left(-\frac{1}{2}\left(\frac{1}{\sigma^2} x^2-\frac{2 \mu}{\sigma^2} x+\frac{\mu^2}{\sigma^2}\right)\right) \quad(18)
exp(−2σ2(x−μ)2)=exp(−21(σ21x2−σ22μx+σ2μ2))(18)
对比着(17)(18)来看,我们发现其实A就是
1
σ
2
\frac{1}{\sigma^2}
σ21,是个定值,只需要知道
α
t
\alpha_t
αt就知道
σ
2
\sigma^2
σ2,而
μ
~
t
(
x
t
,
x
0
)
\tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right)
μ~t(xt,x0)从B整合一下就可以得到(对比(16)(18)):
1
σ
t
−
1
2
=
α
t
β
t
+
1
1
−
α
ˉ
t
−
1
(
19
)
\boldsymbol{\frac{1}{\sigma_{t-1}^2}} = \frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}} \quad(19)
σt−121=βtαt+1−αˉt−11(19)
μ
~
t
−
1
(
x
t
−
1
,
x
0
)
=
α
t
(
1
−
α
ˉ
t
−
1
)
1
−
α
ˉ
t
x
t
+
α
ˉ
t
−
1
β
t
1
−
α
ˉ
t
x
0
(
20
)
\tilde{\boldsymbol{\mu}}_{t-1}\left(\mathbf{x}_{t-1}, \mathbf{x}_0\right)=\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} \mathbf{x}_0 \quad(20)
μ~t−1(xt−1,x0)=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1βtx0(20)
注意这里的
μ
,
σ
2
\mu,\sigma^2
μ,σ2代表的是逆扩散过程中,我们根据
x
t
x_t
xt猜出来的
x
t
−
1
x_{t-1}
xt−1分布的均值和方差。
是不是觉得,唉,怎么又是这么长的式子,稳住!马上快结束了!!!
不知道大家有没有发现我们求出来的均值(20)有什么不对劲的地方。我们先快速回顾一下我们在解决什么问题。之前的前向扩散是已知
x
t
−
1
x_{t-1}
xt−1去求
x
t
x_t
xt,而我们现在想求的是逆扩散过程,也就是已知
x
t
x_{t}
xt去求
x
t
−
1
x_{t-1}
xt−1,最终目标是能得到
x
0
x_0
x0也就是复原后的图片。但是式子(20)里面告诉我们,求均值需要用到
x
0
x_0
x0,这该咋办呢?只好借助前向扩散过程中的公式(2),用
x
t
x_t
xt估计
x
0
x_0
x0:
公式(2):
x
t
=
α
ˉ
t
x
0
+
1
−
α
ˉ
t
z
t
,
α
ˉ
t
=
α
t
α
t
−
1
α
t
−
2
.
.
.
(
2
)
x_t=\sqrt{\bar{\alpha}_t} x_0+\sqrt{1-\bar{\alpha}_t} {z}_t,\bar{\alpha}_t=\alpha_t \alpha_{t-1} \alpha_{t-2}... \quad(2)
xt=αˉtx0+1−αˉtzt,αˉt=αtαt−1αt−2...(2)
移项可得:
x
0
=
1
α
ˉ
t
(
x
t
−
1
−
α
ˉ
t
z
t
)
(
21
)
\mathbf{x}_0=\frac{1}{\sqrt{\bar{\alpha} t}}\left(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \mathbf{z}_t\right) \quad(21)
x0=αˉt1(xt−1−αˉtzt)(21)
把公式(21)代入(20)并整理一下:
μ
~
t
−
1
(
x
t
−
1
)
=
1
α
t
(
x
t
−
β
t
1
−
α
ˉ
t
z
t
)
(
22
)
\tilde{\boldsymbol{\mu}}_{t-1}\left(\mathbf{x}_{t-1}\right)=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \mathbf{z}_t\right) \quad(22)
μ~t−1(xt−1)=αt1(xt−1−αˉtβtzt)(22)
终于,我们好像知道了如何在已知
x
t
x_t
xt的情况下去估计
x
t
−
1
x_{t-1}
xt−1的分布。不过还剩一个问题,
z
t
z_t
zt是啥啊?这里可不一定是高斯分布了,我们只在前向扩散过程中定义它为服从高斯分布的噪声,可没说在逆扩散过程中也定义为高斯分布啊。那怎么办呢?这时候,该轮到我们的机器学习闪亮登场了!!公式无法推导的,那就上机器学习暴力求解呗!我们需要这样一个模型,输入为
x
t
x_t
xt,输出为噪声
z
t
z_{t}
zt。
那既然用到机器学习,我们就需要标签,这里的标签就是在前向扩散过程的每一次迭代中,我们从高斯分布中采样得到的噪声,这个噪声我们是可以在前向扩散过程中记录下来的,因为前向加了什么噪声我们肯定知道嘛。那在逆扩散过程中,我们只需要把前向扩散过程中对应记录下来的噪声作为我们的标签,就可以训练模型,使得模型根据 x t x_t xt预测 z t z_t zt,从而根据 z t z_t zt计算t-1时刻的分布的均值 (公式(22)),又因为方差是个定值 (公式(19)),我们就可以求得t-1时刻的分布了。
我自己画了个图,帮助理解:
所以总结来说,逆扩散过程就是使预测出来的噪声和前向扩散过程中加的噪声距离越小越好。
终于!!!把前向扩散过程和逆扩散过程都解释完了!!也恭喜大家看到了这里!!撒花!!如果觉得绕也很正常,看多几遍就好啦!!
最后附上论文里的训练和预测流程图:
在training阶段的第五步就是主要用到了公式(2),求得了t时刻的
x
t
x_t
xt,同时值得注意的是,传入模型的还有此刻的t。因为之前提到
α
t
\alpha_t
αt会随着t增大而减小,于是噪声也会加的越来越大,所以把t传入相当于多提供一些信息帮助训练模型来预测噪声。在sampling过程中,step4用到了公式(19)和(22),用预测的噪声算出前一时刻
x
t
−
1
x_{t-1}
xt−1的分布,一步步向前传,直到预测出
x
0
x_0
x0。
总结
扩散模型分为前向扩散过程和逆扩散过程。前向扩散过程是从 x 0 x_0 x0算到 x T x_T xT,逆扩散是从 x T x_T xT算到 x 0 x_0 x0。
在前扩散过程中,主要做的事情就是迭代的往图片上加服从高斯分布的噪声,从 x 0 x_0 x0逐渐加到 x T x_T xT,我们在这一部分逐步推导了公式,核心公式为公式(1)和(2)。公式(1)告诉我们如何从 x t − 1 x_{t-1} xt−1算出 x t x_t xt,而公式(2)告诉我们如何根据 x 0 x_0 x0和时刻t算出 x t x_t xt。
在逆扩散过程中,我们先利用贝叶斯公式,将前一时刻的分布,转化为可利用前向扩散过程中的公式计算的式子,发现转换后的公式主要由三个高斯分布组成,而这三个高斯分布可以组合成一个新的高斯分布,也就是前一时刻的分布。之后经过一系列化简合并,我们将前一时刻的分布的均值和方差求了出来。不过前一时刻分布的均值里面包含了一个我们无法直接求得的 z t z_t zt,所以我们需要借助机器学习去估计这个 z t z_t zt,然后利用预测出来的 z t z_t zt求得前一时刻的分布,然后逐步迭代直到算出 x 0 x_0 x0。核心公式为公式(19)和(22)。
写文不易,如果觉得还不错可以点个赞鼓励一下哦!同时如果文中出现错误,非常欢迎指正!