因CSDN正文字数限制,只能拆分为两个文档,接上文继续
无中生有的目标分布
如果我们有过模型训练经验,那就一定知道,需要有一个目标值或GT值与模型预测值计算损失才能实现训练闭环。在前文中我们说到,DDMP就是对“能基于 x t x_t xt计算出 x t − 1 x_{t-1} xt−1”的逆向分布过程建模,模型训练预测出的分布用 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt−1∣xt)表示,那与预测分布进行比较的目标分布应该如何表示呢?
假设目标分布用 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt−1∣xt)表示,请问计算出 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt−1∣xt)这个分布可能吗?理论上,是可能的,但需要遍历整个数据集,通过大量计算,才能实现。如果我们真的大费周章计算出了 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt−1∣xt),那也没有训练模型的必要了,反正已经知道了目标分布,直接采样就可以了。现在就遇到了一个进退两难的问题,模型训练需要一个目标分布,但目标分布难以获得,并且获得之后就没有训练的必要了。
在深度学习中,特别是生成模型,大多数情况下,训练和推理的起点是不同的。 生成时,目标就是无中生有,是直接从一个随机噪声中生成结果;但训练不一样,我们是从目标分布的数据集开始的,我们是有
x
0
x_0
x0的;
q
(
x
t
−
1
∣
x
t
)
q(x_{t-1}|x_t)
q(xt−1∣xt)做不到的事情,
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t,x_0)
q(xt−1∣xt,x0)可以做的到。这是一个顺利成章的想法,因为DDPM的训练过程本来就是先从
x
0
x_0
x0数据加噪到
x
T
x_T
xT,然后再通过模型训练进行去噪。将
x
0
x_0
x0作为条件引入到目标分布中进而解决目标分布未知问题的思想,在后续其他扩散模型、分数模型或流模型等生成范式中都会用到。直接将
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t,x_0)
q(xt−1∣xt,x0)拿出来会有些突兀,但其实
q
(
x
t
−
1
∣
x
t
)
q(x_{t-1}|x_t)
q(xt−1∣xt)和
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t,x_0)
q(xt−1∣xt,x0)是等价的,因为去噪过程也是马尔可夫过程,即
x
t
−
1
x_{t-1}
xt−1只与
x
t
x_t
xt相关,与
x
0
x_0
x0无关,故将
x
0
x_0
x0添加到条件概率
q
(
x
t
−
1
∣
x
t
)
q(x_{t-1}|x_t)
q(xt−1∣xt)中没有任务影响。虽说
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t,x_0)
q(xt−1∣xt,x0)可以作为目标分布,但一眼看上去
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t,x_0)
q(xt−1∣xt,x0)还是未知的,就要看下面的公式推导:
q
(
x
t
−
1
∣
x
t
,
x
0
)
=
q
(
x
t
,
x
t
−
1
,
x
0
)
q
(
x
t
,
x
0
)
=
q
(
x
t
∣
x
t
−
1
,
x
0
)
q
(
x
t
−
1
,
x
0
)
q
(
x
t
,
x
0
)
=
q
(
x
t
∣
x
t
−
1
,
x
0
)
q
(
x
t
−
1
∣
x
0
)
q
(
x
0
)
q
(
x
t
∣
x
0
)
q
(
x
0
)
=
q
(
x
t
∣
x
t
−
1
)
q
(
x
t
−
1
∣
x
0
)
q
(
x
t
∣
x
0
)
\begin{align*} q(x_{t-1}|x_t,x_0) & = \frac{q(x_t,x_{t-1},x_0)}{q(x_t,x_0)} \\ & = \frac{q(x_t|x_{t-1},x_0)q(x_{t-1},x_0)}{q(x_t,x_0)} \\ & = q(x_t|x_{t-1},x_0)\frac{q(x_{t-1}|x_0)q(x_0)}{q(x_t|x_0)q(x_0)} \\ & = q(x_t|x_{t-1})\frac{q(x_{t-1}|x_0)}{q(x_t|x_0)} \tag5 \end{align*}
q(xt−1∣xt,x0)=q(xt,x0)q(xt,xt−1,x0)=q(xt,x0)q(xt∣xt−1,x0)q(xt−1,x0)=q(xt∣xt−1,x0)q(xt∣x0)q(x0)q(xt−1∣x0)q(x0)=q(xt∣xt−1)q(xt∣x0)q(xt−1∣x0)(5)
通过上述推导,有一种柳暗花明又一村的感觉,因为
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t,x_0)
q(xt−1∣xt,x0)等价三个前向加噪过程组合表征,而我们前面已经推导出了此处所需的所有公式。公式(5)中的三项都符合高斯分布,直接将高斯分布概率密度函数
f
(
x
∣
μ
,
σ
2
)
=
1
2
π
σ
2
exp
(
−
(
x
−
μ
)
2
2
σ
2
)
f(x|\mu,\sigma^2)=\frac{1}{\sqrt{2\pi\sigma^2}}\exp(-\frac{(x-\mu)^2}{2\sigma^2})
f(x∣μ,σ2)=2πσ21exp(−2σ2(x−μ)2)带入计算,有
q
(
x
t
−
1
∣
x
t
,
x
0
)
=
q
(
x
t
∣
x
t
−
1
)
q
(
x
t
−
1
∣
x
0
)
q
(
x
t
∣
x
0
)
=
N
(
x
t
;
α
t
x
t
−
1
,
(
1
−
α
t
)
I
)
N
(
x
t
−
1
;
α
ˉ
t
−
1
x
0
,
(
1
−
α
ˉ
t
−
1
)
I
)
N
(
x
t
;
α
ˉ
t
x
0
,
(
1
−
α
ˉ
t
)
I
)
∝
exp
{
−
[
(
x
t
−
α
t
x
t
−
1
)
2
2
(
1
−
α
t
)
+
(
x
t
−
1
−
α
ˉ
t
−
1
x
0
)
2
2
(
1
−
α
ˉ
t
−
1
)
−
(
x
t
−
α
ˉ
t
x
0
)
2
2
(
1
−
α
ˉ
t
)
]
}
=
exp
{
−
1
2
[
(
x
t
−
α
t
x
t
−
1
)
2
1
−
α
t
+
(
x
t
−
1
−
α
ˉ
t
−
1
x
0
)
2
1
−
α
ˉ
t
−
1
−
(
x
t
−
α
ˉ
t
x
0
)
2
1
−
α
ˉ
t
]
}
=
exp
{
−
1
2
[
(
x
t
2
−
2
α
t
x
t
−
1
x
t
+
α
t
x
t
−
1
2
)
1
−
α
t
+
(
x
t
−
1
2
−
2
α
ˉ
t
−
1
x
t
−
1
x
0
+
α
ˉ
t
−
1
x
0
2
)
1
−
α
ˉ
t
−
1
−
(
x
t
−
α
ˉ
t
x
0
)
2
1
−
α
ˉ
t
]
}
\begin{align*} q(x_{t-1}|x_t,x_0) & = q(x_t|x_{t-1})\frac{q(x_{t-1}|x_0)}{q(x_t|x_0)} \\ & = \frac{N(x_t;\sqrt{\alpha_t}x_{t-1},(1-\alpha_t) I)N(x_{t-1};\sqrt{ \bar{\alpha}_{t-1}} x_0, (1- \bar{\alpha}_{t-1})I)}{N(x_t;\sqrt{ \bar{\alpha}_t} x_0, (1- \bar{\alpha}_t)I)} \\ & \propto \exp \lbrace -[\frac{(x_t-\sqrt{\alpha_t}x_{t-1})^2}{2(1-\alpha_t)} + \frac{(x_{t-1}-\sqrt{ \bar{\alpha}_{t-1}} x_0)^2}{2(1- \bar{\alpha}_{t-1})} - \frac{(x_t-\sqrt{ \bar{\alpha}_t} x_0)^2}{2(1- \bar{\alpha}_t)}] \rbrace \\ & = \exp \lbrace -\frac{1}{2}[\frac{(x_t-\sqrt{\alpha_t}x_{t-1})^2}{1-\alpha_t} + \frac{(x_{t-1}-\sqrt{ \bar{\alpha}_{t-1}} x_0)^2}{1- \bar{\alpha}_{t-1}} - \frac{(x_t-\sqrt{ \bar{\alpha}_t} x_0)^2}{1- \bar{\alpha}_t}] \rbrace \\ & = \exp \lbrace -\frac{1}{2}[\frac{(x_t^2-2\sqrt{\alpha_t}x_{t-1}x_t+\alpha_tx_{t-1}^2 )}{1-\alpha_t} + \frac{(x_{t^-1}^2-2\sqrt{\bar{\alpha}_{t-1}}x_{t-1}x_0+\bar{\alpha}_{t-1}x_0^2)}{1- \bar{\alpha}_{t-1}} - \frac{(x_t-\sqrt{ \bar{\alpha}_t} x_0)^2}{1- \bar{\alpha}_t}] \rbrace \\ \end{align*}
q(xt−1∣xt,x0)=q(xt∣xt−1)q(xt∣x0)q(xt−1∣x0)=N(xt;αˉtx0,(1−αˉt)I)N(xt;αtxt−1,(1−αt)I)N(xt−1;αˉt−1x0,(1−αˉt−1)I)∝exp{−[2(1−αt)(xt−αtxt−1)2+2(1−αˉt−1)(xt−1−αˉt−1x0)2−2(1−αˉt)(xt−αˉtx0)2]}=exp{−21[1−αt(xt−αtxt−1)2+1−αˉt−1(xt−1−αˉt−1x0)2−1−αˉt(xt−αˉtx0)2]}=exp{−21[1−αt(xt2−2αtxt−1xt+αtxt−12)+1−αˉt−1(xt−12−2αˉt−1xt−1x0+αˉt−1x02)−1−αˉt(xt−αˉtx0)2]}
上述推理过程很直接,就是将三个高斯分布用概率密函数表示后,再将两项与
x
t
−
1
x_{t-1}
xt−1相关的平方项解开。为什么不把与
x
t
−
1
x_{t-1}
xt−1无关的最后一项解开呢?因为我们的目的是计算
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t,x_0)
q(xt−1∣xt,x0),
x
t
−
1
x_{t-1}
xt−1是目标,而最后一项
(
x
t
−
α
ˉ
t
x
0
)
2
1
−
α
ˉ
t
\frac{(x_t-\sqrt{ \bar{\alpha}_t} x_0)^2}{1- \bar{\alpha}_t}
1−αˉt(xt−αˉtx0)2只与
x
0
x_0
x0、
x
t
x_t
xt、
α
\alpha
α相关,而
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t,x_0)
q(xt−1∣xt,x0)是表示已给定
x
0
x_0
x0、
x
t
x_t
xt时
x
t
−
1
x_{t-1}
xt−1的概率,即
x
0
x_0
x0、
x
t
x_t
xt是已知的,
(
x
t
−
α
ˉ
t
x
0
)
2
1
−
α
ˉ
t
\frac{(x_t-\sqrt{ \bar{\alpha}_t} x_0)^2}{1- \bar{\alpha}_t}
1−αˉt(xt−αˉtx0)2相当于是一个常数,本质上对
x
t
−
1
x_{t-1}
xt−1的计算没有影响。上述推导将前两个平方项展开是为了合并,那么此时我们就能想到,应该将以含有
x
t
−
1
x_{t-1}
xt−1的项为公共项进行合并。
q
(
x
t
−
1
∣
x
t
,
x
0
)
∝
exp
{
−
1
2
[
(
x
t
2
−
2
α
t
x
t
−
1
x
t
+
α
t
x
t
−
1
2
)
1
−
α
t
+
(
x
t
−
1
2
−
2
α
ˉ
t
−
1
x
t
−
1
x
0
+
α
ˉ
t
−
1
x
0
2
)
1
−
α
ˉ
t
−
1
−
(
x
t
−
α
ˉ
t
x
0
)
2
1
−
α
ˉ
t
]
}
=
exp
{
−
1
2
[
(
α
t
1
−
α
t
+
1
1
−
α
ˉ
t
−
1
)
x
t
−
1
2
−
(
2
α
t
x
t
1
−
α
t
+
2
α
ˉ
t
−
1
x
0
1
−
α
ˉ
t
−
1
)
x
t
−
1
+
(
x
t
2
1
−
α
t
+
α
ˉ
t
−
1
x
0
2
1
−
α
ˉ
t
−
1
+
(
x
t
−
α
ˉ
t
x
0
)
2
1
−
α
ˉ
t
)
]
}
\begin{align*} q(x_{t-1}|x_t,x_0) & \propto \exp \lbrace -\frac{1}{2}[\frac{(x_t^2-\color{blue}{2\sqrt{\alpha_t}x_{t-1}x_t}\color{grey}+\color{red}\alpha_tx_{t-1}^2\color{grey})}{1-\alpha_t} + \frac{(\color{red}x_{t^-1}^2\color{grey}-\color{blue}2\sqrt{\bar{\alpha}_{t-1}}x_{t-1}x_0\color{grey}+\bar{\alpha}_{t-1}x_0^2)}{1- \bar{\alpha}_{t-1}} - \frac{(x_t-\sqrt{ \bar{\alpha}_t} x_0)^2}{1- \bar{\alpha}_t}] \rbrace \\ & = \exp \lbrace -\frac{1}{2}[(\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1- \bar{\alpha}_{t-1}})x_{t-1}^2-(\frac{2\sqrt{\alpha_t}x_t}{1-\alpha_t}+\frac{2\sqrt{\bar{\alpha}_{t-1}}x_0}{1- \bar{\alpha}_{t-1}})x_{t-1} + (\frac{x_t^2}{1-\alpha_t}+\frac{\bar{\alpha}_{t-1}x_0^2}{1- \bar{\alpha}_{t-1}}+\frac{(x_t-\sqrt{ \bar{\alpha}_t} x_0)^2}{1- \bar{\alpha}_t})] \rbrace \\ \end{align*}
q(xt−1∣xt,x0)∝exp{−21[1−αt(xt2−2αtxt−1xt+αtxt−12)+1−αˉt−1(xt−12−2αˉt−1xt−1x0+αˉt−1x02)−1−αˉt(xt−αˉtx0)2]}=exp{−21[(1−αtαt+1−αˉt−11)xt−12−(1−αt2αtxt+1−αˉt−12αˉt−1x0)xt−1+(1−αtxt2+1−αˉt−1αˉt−1x02+1−αˉt(xt−αˉtx0)2)]}
上述推导过程是将包含
x
t
−
1
x_{t-1}
xt−1和
x
t
−
1
2
x_{t-1}^2
xt−12的同类项合并,得到前两项,然后剩下与
x
t
−
1
x_{t-1}
xt−1无关组的项成第三项,前文已经讨论了,与
x
t
−
1
x_{t-1}
xt−1无关的项,其实是一个常数,可用
C
(
x
t
,
x
0
)
C(x_t,x_0)
C(xt,x0)表示,并且后续可通过正比于/
∝
\propto
∝关系直接忽略此项。
q
(
x
t
−
1
∣
x
t
,
x
0
)
∝
exp
{
−
1
2
[
(
α
t
1
−
α
t
+
1
1
−
α
ˉ
t
−
1
)
x
t
−
1
2
−
(
2
α
t
x
t
1
−
α
t
+
2
α
ˉ
t
−
1
x
0
1
−
α
ˉ
t
−
1
)
x
t
−
1
+
C
(
x
t
,
x
0
)
]
}
∝
exp
{
−
1
2
[
(
α
t
1
−
α
t
+
1
1
−
α
ˉ
t
−
1
)
x
t
−
1
2
−
2
(
α
t
x
t
1
−
α
t
+
α
ˉ
t
−
1
x
0
1
−
α
ˉ
t
−
1
)
x
t
−
1
]
}
=
exp
{
−
1
2
[
α
t
(
1
−
α
ˉ
t
−
1
)
+
1
−
α
t
(
1
−
α
t
)
(
1
−
α
ˉ
t
−
1
)
x
t
−
1
2
−
2
(
α
t
x
t
1
−
α
t
+
α
ˉ
t
−
1
x
0
1
−
α
ˉ
t
−
1
)
x
t
−
1
]
}
=
exp
{
−
1
2
[
α
t
−
α
ˉ
t
+
1
−
α
t
(
1
−
α
t
)
(
1
−
α
ˉ
t
−
1
)
x
t
−
1
2
−
2
(
α
t
x
t
1
−
α
t
+
α
ˉ
t
−
1
x
0
1
−
α
ˉ
t
−
1
)
x
t
−
1
]
}
=
exp
{
−
1
2
[
1
−
α
ˉ
t
(
1
−
α
t
)
(
1
−
α
ˉ
t
−
1
)
x
t
−
1
2
−
2
(
α
t
x
t
1
−
α
t
+
α
ˉ
t
−
1
x
0
1
−
α
ˉ
t
−
1
)
x
t
−
1
]
}
=
exp
{
−
1
2
(
1
−
α
ˉ
t
(
1
−
α
t
)
(
1
−
α
ˉ
t
−
1
)
)
[
x
t
−
1
2
−
2
(
α
t
x
t
1
−
α
t
+
α
ˉ
t
−
1
x
0
1
−
α
ˉ
t
−
1
)
1
−
α
ˉ
t
(
1
−
α
t
)
(
1
−
α
ˉ
t
−
1
)
x
t
−
1
]
}
=
exp
{
−
1
2
(
1
−
α
ˉ
t
(
1
−
α
t
)
(
1
−
α
ˉ
t
−
1
)
)
[
x
t
−
1
2
−
2
(
α
t
x
t
1
−
α
t
+
α
ˉ
t
−
1
x
0
1
−
α
ˉ
t
−
1
)
(
1
−
α
t
)
(
1
−
α
ˉ
t
−
1
)
1
−
α
ˉ
t
x
t
−
1
]
}
=
exp
{
−
1
2
(
1
(
1
−
α
t
)
(
1
−
α
ˉ
t
−
1
)
1
−
α
ˉ
t
)
[
x
t
−
1
2
−
2
α
t
(
1
−
α
ˉ
t
−
1
)
x
t
+
α
ˉ
t
−
1
(
1
−
α
t
)
x
0
1
−
α
ˉ
t
x
t
−
1
]
}
∝
N
(
x
t
−
1
;
α
t
(
1
−
α
ˉ
t
−
1
)
x
t
+
α
ˉ
t
−
1
(
1
−
α
t
)
x
0
1
−
α
ˉ
t
⏟
μ
q
(
x
t
,
x
0
)
,
(
1
−
α
t
)
(
1
−
α
ˉ
t
−
1
)
1
−
α
ˉ
t
I
⏟
Σ
q
(
t
)
)
\begin{align*} q(x_{t-1}|x_t,x_0) & \propto \exp \lbrace -\frac{1}{2}[(\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1- \bar{\alpha}_{t-1}})x_{t-1}^2-(\frac{2\sqrt{\alpha_t}x_t}{1-\alpha_t}+\frac{2\sqrt{\bar{\alpha}_{t-1}}x_0}{1- \bar{\alpha}_{t-1}})x_{t-1} + C(x_t,x_0)] \rbrace \\ & \propto \exp \lbrace -\frac{1}{2}[(\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1- \bar{\alpha}_{t-1}})x_{t-1}^2-2(\frac{\sqrt{\alpha_t}x_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}}x_0}{1- \bar{\alpha}_{t-1}})x_{t-1}] \rbrace \\ & = \exp \lbrace -\frac{1}{2}[\frac{\alpha_t(1- \bar{\alpha}_{t-1})+1-\alpha_t}{(1-\alpha_t)(1- \bar{\alpha}_{t-1})}x_{t-1}^2-2(\frac{\sqrt{\alpha_t}x_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}}x_0}{1- \bar{\alpha}_{t-1}})x_{t-1}] \rbrace \\ & = \exp \lbrace -\frac{1}{2}[\frac{\alpha_t- \bar{\alpha}_t+1-\alpha_t}{(1-\alpha_t)(1- \bar{\alpha}_{t-1})}x_{t-1}^2-2(\frac{\sqrt{\alpha_t}x_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}}x_0}{1- \bar{\alpha}_{t-1}})x_{t-1}] \rbrace \\ & = \exp \lbrace -\frac{1}{2}[\frac{1- \bar{\alpha}_t}{(1-\alpha_t)(1- \bar{\alpha}_{t-1})}x_{t-1}^2-2(\frac{\sqrt{\alpha_t}x_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}}x_0}{1- \bar{\alpha}_{t-1}})x_{t-1}] \rbrace \\ & = \exp \lbrace -\frac{1}{2}(\frac{1- \bar{\alpha}_t}{(1-\alpha_t)(1- \bar{\alpha}_{t-1})})[x_{t-1}^2-2\frac{(\frac{\sqrt{\alpha_t}x_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}}x_0}{1- \bar{\alpha}_{t-1}})}{\frac{1- \bar{\alpha}_t}{(1-\alpha_t)(1- \bar{\alpha}_{t-1})}}x_{t-1}] \rbrace \\ & = \exp \lbrace -\frac{1}{2}(\frac{1- \bar{\alpha}_t}{(1-\alpha_t)(1- \bar{\alpha}_{t-1})})[x_{t-1}^2-2\frac{(\frac{\sqrt{\alpha_t}x_t}{1-\alpha_t}+\frac{\sqrt{\bar{\alpha}_{t-1}}x_0}{1- \bar{\alpha}_{t-1}})(1-\alpha_t)(1- \bar{\alpha}_{t-1})}{1- \bar{\alpha}_t}x_{t-1}] \rbrace \\ & = \exp \lbrace -\frac{1}{2}(\frac{1}{\frac{(1-\alpha_t)(1- \bar{\alpha}_{t-1})}{1- \bar{\alpha}_t}})[x_{t-1}^2-2\frac{\sqrt{\alpha_t}(1- \bar{\alpha}_{t-1})x_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)x_0}{1- \bar{\alpha}_t}x_{t-1}] \rbrace \\ & \propto N(x_{t-1};\underbrace{\frac{\sqrt{\alpha_t}(1- \bar{\alpha}_{t-1})x_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)x_0}{1- \bar{\alpha}_t}}_{\mu_q(x_t,x_0)},\underbrace{\frac{(1-\alpha_t)(1- \bar{\alpha}_{t-1})}{1- \bar{\alpha}_t}I}_{\Sigma_q(t)}) \tag6 \end{align*}
q(xt−1∣xt,x0)∝exp{−21[(1−αtαt+1−αˉt−11)xt−12−(1−αt2αtxt+1−αˉt−12αˉt−1x0)xt−1+C(xt,x0)]}∝exp{−21[(1−αtαt+1−αˉt−11)xt−12−2(1−αtαtxt+1−αˉt−1αˉt−1x0)xt−1]}=exp{−21[(1−αt)(1−αˉt−1)αt(1−αˉt−1)+1−αtxt−12−2(1−αtαtxt+1−αˉt−1αˉt−1x0)xt−1]}=exp{−21[(1−αt)(1−αˉt−1)αt−αˉt+1−αtxt−12−2(1−αtαtxt+1−αˉt−1αˉt−1x0)xt−1]}=exp{−21[(1−αt)(1−αˉt−1)1−αˉtxt−12−2(1−αtαtxt+1−αˉt−1αˉt−1x0)xt−1]}=exp{−21((1−αt)(1−αˉt−1)1−αˉt)[xt−12−2(1−αt)(1−αˉt−1)1−αˉt(1−αtαtxt+1−αˉt−1αˉt−1x0)xt−1]}=exp{−21((1−αt)(1−αˉt−1)1−αˉt)[xt−12−21−αˉt(1−αtαtxt+1−αˉt−1αˉt−1x0)(1−αt)(1−αˉt−1)xt−1]}=exp{−21(1−αˉt(1−αt)(1−αˉt−1)1)[xt−12−21−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)x0xt−1]}∝N(xt−1;μq(xt,x0)
1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)x0,Σq(t)
1−αˉt(1−αt)(1−αˉt−1)I)(6)
从上述公式(6)的上一步推导出公式(6)其实是少写了一个常数项,与 f ( x ∣ μ , σ 2 ) = 1 2 π σ 2 exp ( − ( x − μ ) 2 2 σ 2 ) = 1 2 π σ 2 exp ( − x 2 − 2 μ x + μ 2 2 σ 2 ) f(x|\mu,\sigma^2)=\frac{1}{\sqrt{2\pi\sigma^2}}\exp(-\frac{(x-\mu)^2}{2\sigma^2})=\frac{1}{\sqrt{2\pi\sigma^2}}\exp(-\frac{x^2-2\mu x+\mu^2}{2\sigma^2}) f(x∣μ,σ2)=2πσ21exp(−2σ2(x−μ)2)=2πσ21exp(−2σ2x2−2μx+μ2)对比可知,公式(6)上一行中只有包含 x t − 1 2 x_{t-1}^2 xt−12和 x t − 1 x_{t-1} xt−1两项,分别对应标准公式中的 x 2 x^2 x2和 2 μ x 2\mu x 2μx,但少了 μ 2 \mu^2 μ2这一项,但其实因为最终的 μ \mu μ只与 x 0 x_0 x0、 x t x_t xt、 α \alpha α相关,是一个常数项,故完全可以通过加一减一的操作构建一个 μ 2 \mu^2 μ2来,再通过正比于/ ∝ \propto ∝关系忽略多出的一个常数项。
经过上述复杂的推导过程,得出的结论是, q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt−1∣xt,x0)也符合高斯分布,其均值记为 μ q ( x t , x 0 ) \mu_q(x_t,x_0) μq(xt,x0),是一个 x t , x 0 x_t,x_0 xt,x0的函数,训练过程中 x t , x 0 x_t,x_0 xt,x0是已知的,故可以计算出 μ q ( x t , x 0 ) \mu_q(x_t,x_0) μq(xt,x0);方差记为 Σ q ( t ) \Sigma_q(t) Σq(t),仅与 α \alpha α有关,在DDPM中 α t = 1 − β t \alpha_t = 1-\beta_t αt=1−βt, β t \beta_t βt是人为设置的超参数,方差 Σ q ( t ) \Sigma_q(t) Σq(t)可直接计算。
至此,通过大量推导,我们得到了一个可以作为训练目标且可解的分布,即公式(6)。
设计感强但有效的损失函数
在深度学习中,损失函数是很重要的一环,算法研究人员往往会从训练目标出发构建各式各样的损失函数,了解早期用GAN实现图片风格迁移的读者应该对此有强烈的感受。损失函数的构建,开发者解释说明时一般都带有假设性,实际均是从结果出发,发表出来的首先得效果好,其次才是基于最终使用得损失函数进行补充性解释。基于公式(6)中,有以下结论:
μ
q
(
x
t
,
x
0
)
=
α
t
(
1
−
α
ˉ
t
−
1
)
x
t
+
α
ˉ
t
−
1
(
1
−
α
t
)
x
0
1
−
α
ˉ
t
\mu_q(x_t,x_0)=\frac{\sqrt{\alpha_t}(1- \bar{\alpha}_{t-1})x_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)x_0}{1- \bar{\alpha}_t}
μq(xt,x0)=1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)x0
Σ
q
(
t
)
=
(
1
−
α
t
)
(
1
−
α
ˉ
t
−
1
)
1
−
α
ˉ
t
I
=
σ
q
2
(
t
)
\Sigma_q(t)=\frac{(1-\alpha_t)(1- \bar{\alpha}_{t-1})}{1- \bar{\alpha}_t}I=\sigma_q^2(t)
Σq(t)=1−αˉt(1−αt)(1−αˉt−1)I=σq2(t)
训练目标是将模型预测分布
p
θ
(
x
t
−
1
∣
x
t
)
p_{\theta}(x_{t-1}|x_t)
pθ(xt−1∣xt)尽可能接近目标分布
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t,x_0)
q(xt−1∣xt,x0),此时就体现出构建损失函数的假设性,即
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t,x_0)
q(xt−1∣xt,x0)是高斯分布,那么显然
p
θ
(
x
t
−
1
∣
x
t
)
p_{\theta}(x_{t-1}|x_t)
pθ(xt−1∣xt)也应该是一个高斯分布,将其均值记为
μ
θ
\mu_{\theta}
μθ,方差记为
Σ
θ
\Sigma_{\theta}
Σθ。前文已提到
Σ
q
(
t
)
\Sigma_q(t)
Σq(t)是一个不需要模型学习的参数,故可直接假设
Σ
q
(
t
)
=
Σ
θ
\Sigma_q(t)=\Sigma_{\theta}
Σq(t)=Σθ。DDPM论文作者认为此方差学习与否,对最终效果影响不大,故没有使用模型进行学习。现有
p
θ
(
x
t
−
1
∣
x
t
)
∼
N
(
x
t
−
1
;
μ
θ
,
Σ
q
(
t
)
)
p_{\theta}(x_{t-1}|x_t) \sim N(x_{t-1};\mu_{\theta},\Sigma_q(t))
pθ(xt−1∣xt)∼N(xt−1;μθ,Σq(t))
现在模型预测分布
p
θ
(
x
t
−
1
∣
x
t
)
p_{\theta}(x_{t-1}|x_t)
pθ(xt−1∣xt)和目标分布
q
(
x
t
−
1
∣
x
t
,
x
0
)
q(x_{t-1}|x_t,x_0)
q(xt−1∣xt,x0)的均值和方差都已确定,可通过计算KL散度来表征两个分布之间的差距。根据前半部分的前置知识中提到的两个高斯分布的KL散度计算公式:
P
∼
N
(
μ
1
,
σ
1
2
I
)
P \sim N(\mu_1,\sigma_1^2I)
P∼N(μ1,σ12I)和
Q
∼
N
(
μ
2
,
σ
2
2
I
)
Q \sim N(\mu_2,\sigma_2^2I)
Q∼N(μ2,σ22I),
D
K
L
(
P
∣
∣
Q
)
=
log
σ
2
σ
1
+
σ
1
2
+
(
μ
1
−
μ
2
)
2
2
σ
2
2
−
1
2
D_{KL}(P||Q) = \log\frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2}
DKL(P∣∣Q)=logσ1σ2+2σ22σ12+(μ1−μ2)2−21,有
D
K
L
(
q
(
x
t
−
1
∣
x
t
,
x
0
)
∣
∣
p
θ
(
x
t
−
1
∣
x
t
)
)
=
D
K
L
(
N
(
x
t
−
1
;
μ
q
,
Σ
q
(
t
)
)
∣
∣
N
(
x
t
−
1
;
μ
θ
,
Σ
q
(
t
)
)
)
=
log
Σ
q
(
t
)
Σ
q
(
t
)
+
Σ
q
(
t
)
+
(
μ
q
−
μ
θ
)
2
2
Σ
q
(
t
)
−
1
2
=
log
1
+
1
2
+
(
μ
q
−
μ
θ
)
2
2
Σ
q
(
t
)
−
1
2
=
(
μ
q
−
μ
θ
)
2
2
Σ
q
(
t
)
=
1
2
σ
q
2
(
t
)
[
∣
∣
μ
q
−
μ
θ
∣
∣
2
2
]
\begin{align*} D_{KL}(q(x_{t-1}|x_t,x_0)||p_{\theta}(x_{t-1}|x_t)) & = D_{KL}(N(x_{t-1};\mu_q,\Sigma_q(t))||N(x_{t-1};\mu_{\theta},\Sigma_q(t))) \\ & = \log \frac{\sqrt{\Sigma_q(t)}}{\sqrt{\Sigma_q(t)}} + \frac{\Sigma_q(t)+(\mu_q-\mu_{\theta})^2}{2\Sigma_q(t)} - \frac{1}{2} \\ & = \log 1 + \frac{1}{2} + \frac{(\mu_q-\mu_{\theta})^2}{2\Sigma_q(t)} - \frac{1}{2} \\ & = \frac{(\mu_q-\mu_{\theta})^2}{2\Sigma_q(t)} \\ & = \frac{1}{2\sigma_q^2(t)}[||\mu_q-\mu_{\theta}||^2_2] \tag7 \end{align*}
DKL(q(xt−1∣xt,x0)∣∣pθ(xt−1∣xt))=DKL(N(xt−1;μq,Σq(t))∣∣N(xt−1;μθ,Σq(t)))=logΣq(t)Σq(t)+2Σq(t)Σq(t)+(μq−μθ)2−21=log1+21+2Σq(t)(μq−μθ)2−21=2Σq(t)(μq−μθ)2=2σq2(t)1[∣∣μq−μθ∣∣22](7)
又是经过一系列公式推导后,我们得到了公式(7),将两个高斯分布的KL散度简化为了两个分布的均值之间的差异。其实还有一部分基于VAE损失函数ELBO进行拆分、推理的过程,本文将该部分省略,省略后对理解没有任何影响,甚至能提高理解性。
继续推导,公式(7)虽然已用于训练,但是与DDPM中的去噪毫无关系,表明我们还需继续推导。目前有
μ
q
(
x
t
,
x
0
)
=
α
t
(
1
−
α
ˉ
t
−
1
)
x
t
+
α
ˉ
t
−
1
(
1
−
α
t
)
x
0
1
−
α
ˉ
t
\mu_q(x_t,x_0)=\frac{\sqrt{\alpha_t}(1- \bar{\alpha}_{t-1})x_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)x_0}{1- \bar{\alpha}_t}
μq(xt,x0)=1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)x0,其实关于
x
t
x_t
xt和
x
0
x_0
x0的函数,训练时
x
0
x_0
x0是已知的,但预测时
x
0
x_0
x0是未知的,那应该如何而来呢?此时,就要继续用到之前推导的公式
q
(
x
t
∣
x
0
)
∼
N
(
x
t
;
α
ˉ
t
x
0
,
(
1
−
α
ˉ
t
)
I
)
q(x_t|x_0) \sim N(x_t;\sqrt{ \bar{\alpha}_t} x_0, (1- \bar{\alpha}_t)I)
q(xt∣x0)∼N(xt;αˉtx0,(1−αˉt)I),基于其有:
x
t
=
α
ˉ
t
x
0
+
(
1
−
α
ˉ
t
)
ϵ
t
,
α
ˉ
t
=
∏
1
t
α
i
,
ϵ
t
∼
N
(
0
,
I
)
x_t = \sqrt{ \bar{\alpha}_t}x_0+\sqrt{(1- \bar{\alpha}_t)}\epsilon_t,\bar{\alpha}_t=\prod_1^t \alpha_i,\epsilon_t \sim N(0,I)
xt=αˉtx0+(1−αˉt)ϵt,αˉt=1∏tαi,ϵt∼N(0,I)
将上式反过来,就能通过
x
t
x_t
xt计算出
x
0
x_0
x0,即
x
0
=
x
t
−
1
−
α
ˉ
t
ϵ
t
α
ˉ
t
,
ϵ
t
∼
N
(
0
,
I
)
(8)
x_0=\frac{x_t-\sqrt{1- \bar{\alpha}_t}\epsilon_t}{\sqrt{ \bar{\alpha}_t}},\epsilon_t \sim N(0,I) \tag8
x0=αˉtxt−1−αˉtϵt,ϵt∼N(0,I)(8)
将公式(8)带入
μ
q
(
x
t
,
x
0
)
\mu_q(x_t,x_0)
μq(xt,x0)的计算公式有,
μ
q
(
x
t
,
x
0
)
=
α
t
(
1
−
α
ˉ
t
−
1
)
x
t
+
α
ˉ
t
−
1
(
1
−
α
t
)
x
0
1
−
α
ˉ
t
=
α
t
(
1
−
α
ˉ
t
−
1
)
x
t
+
α
ˉ
t
−
1
(
1
−
α
t
)
x
t
−
1
−
α
ˉ
t
ϵ
t
α
ˉ
t
1
−
α
ˉ
t
=
α
t
(
1
−
α
ˉ
t
−
1
)
x
t
1
−
α
ˉ
t
+
(
1
−
α
t
)
x
t
(
1
−
α
ˉ
t
)
α
t
−
(
1
−
α
t
)
1
−
α
ˉ
t
ϵ
t
(
1
−
α
ˉ
t
)
α
t
=
(
α
t
(
1
−
α
ˉ
t
−
1
)
1
−
α
ˉ
t
+
1
−
α
t
(
1
−
α
ˉ
t
)
α
t
)
x
t
−
(
1
−
α
t
)
1
−
α
ˉ
t
ϵ
t
(
1
−
α
ˉ
t
)
α
t
=
(
α
t
(
1
−
α
ˉ
t
−
1
)
(
1
−
α
ˉ
t
)
α
t
+
1
−
α
t
(
1
−
α
ˉ
t
)
α
t
)
x
t
−
(
1
−
α
t
)
ϵ
t
1
−
α
ˉ
t
α
t
=
α
t
−
α
ˉ
t
+
1
−
α
t
(
1
−
α
ˉ
t
)
α
t
x
t
−
(
1
−
α
t
)
ϵ
t
1
−
α
ˉ
t
α
t
=
1
−
α
ˉ
t
(
1
−
α
ˉ
t
)
α
t
x
t
−
(
1
−
α
t
)
1
−
α
ˉ
t
α
t
ϵ
t
=
1
α
t
x
t
−
(
1
−
α
t
)
1
−
α
ˉ
t
α
t
ϵ
t
\begin{align*} \mu_q(x_t,x_0)&=\frac{\sqrt{\alpha_t}(1- \bar{\alpha}_{t-1})x_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)x_0}{1- \bar{\alpha}_t} \\ &=\frac{\sqrt{\alpha_t}(1- \bar{\alpha}_{t-1})x_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\frac{x_t-\sqrt{1- \bar{\alpha}_t}\epsilon_t}{\sqrt{ \bar{\alpha}_t}}}{1- \bar{\alpha}_t} \\ &=\frac{\sqrt{\alpha_t}(1- \bar{\alpha}_{t-1})x_t}{1- \bar{\alpha}_t}+\frac{(1-\alpha_t)x_t}{(1- \bar{\alpha}_t)\sqrt{\alpha_t}}-\frac{(1-\alpha_t)\sqrt{1- \bar{\alpha}_t}\epsilon_t}{(1- \bar{\alpha}_t)\sqrt{\alpha_t}} \\ &= (\frac{\sqrt{\alpha_t}(1- \bar{\alpha}_{t-1})}{1- \bar{\alpha}_t}+\frac{1-\alpha_t}{(1- \bar{\alpha}_t)\sqrt{\alpha_t}})x_t-\frac{(1-\alpha_t)\sqrt{1- \bar{\alpha}_t}\epsilon_t}{(1- \bar{\alpha}_t)\sqrt{\alpha_t}} \\ &= (\frac{\alpha_t(1- \bar{\alpha}_{t-1})}{(1- \bar{\alpha}_t)\sqrt{\alpha_t}}+\frac{1-\alpha_t}{(1- \bar{\alpha}_t)\sqrt{\alpha_t}})x_t-\frac{(1-\alpha_t)\epsilon_t}{\sqrt{1- \bar{\alpha}_t}\sqrt{\alpha_t}} \\ &= \frac{\alpha_t- \bar{\alpha}_t+1-\alpha_t}{(1- \bar{\alpha}_t)\sqrt{\alpha_t}}x_t-\frac{(1-\alpha_t)\epsilon_t}{\sqrt{1- \bar{\alpha}_t}\sqrt{\alpha_t}} \\ &= \frac{1- \bar{\alpha}_t}{(1- \bar{\alpha}_t)\sqrt{\alpha_t}}x_t-\frac{(1-\alpha_t)}{\sqrt{1- \bar{\alpha}_t}\sqrt{\alpha_t}}\epsilon_t \\ &= \frac{1}{\sqrt{\alpha_t}}x_t-\frac{(1-\alpha_t)}{\sqrt{1- \bar{\alpha}_t}\sqrt{\alpha_t}}\epsilon_t \tag9 \end{align*}
μq(xt,x0)=1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)x0=1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)αˉtxt−1−αˉtϵt=1−αˉtαt(1−αˉt−1)xt+(1−αˉt)αt(1−αt)xt−(1−αˉt)αt(1−αt)1−αˉtϵt=(1−αˉtαt(1−αˉt−1)+(1−αˉt)αt1−αt)xt−(1−αˉt)αt(1−αt)1−αˉtϵt=((1−αˉt)αtαt(1−αˉt−1)+(1−αˉt)αt1−αt)xt−1−αˉtαt(1−αt)ϵt=(1−αˉt)αtαt−αˉt+1−αtxt−1−αˉtαt(1−αt)ϵt=(1−αˉt)αt1−αˉtxt−1−αˉtαt(1−αt)ϵt=αt1xt−1−αˉtαt(1−αt)ϵt(9)
公式(9)中已没有
x
0
x_0
x0,噪声
ϵ
t
\epsilon_t
ϵt也出现了,逐渐与去噪过程联系起来。还是上面的思路,既然目标分布的均值
μ
q
\mu_q
μq是公式(9)的形式,为了尽可能保证模型预测分布与目标分布距离小,在已将均值设置为相等的前提下,将模型预测分布的均值
μ
θ
\mu_{\theta}
μθ也按照公式(9)的形式构建,即
μ
θ
=
μ
θ
(
x
t
,
t
)
=
1
α
t
x
t
−
1
−
α
t
1
−
α
ˉ
t
α
t
ϵ
θ
^
(
x
t
,
t
)
(10)
\mu_{\theta}=\mu_{\theta}(x_t,t)=\frac{1}{\sqrt{\alpha_t}}x_t-\frac{1-\alpha_t}{\sqrt{1- \bar{\alpha}_t}\sqrt{\alpha_t}}\hat{\epsilon_{\theta}}(x_t,t) \tag{10}
μθ=μθ(xt,t)=αt1xt−1−αˉtαt1−αtϵθ^(xt,t)(10)
公式(9)、(10)带入到公式(7)中继续计算
D
K
L
(
q
(
x
t
−
1
∣
x
t
,
x
0
)
∣
∣
p
θ
(
x
t
−
1
∣
x
t
)
)
=
1
2
σ
q
2
(
t
)
[
∣
∣
μ
q
−
μ
θ
∣
∣
2
2
]
=
1
2
σ
q
2
(
t
)
[
∣
∣
1
α
t
x
t
−
1
−
α
t
1
−
α
ˉ
t
α
t
ϵ
t
−
1
α
t
x
t
+
1
−
α
t
1
−
α
ˉ
t
α
t
ϵ
θ
^
(
x
t
,
t
)
∣
∣
2
2
]
=
1
2
σ
q
2
(
t
)
[
∣
∣
−
1
−
α
t
1
−
α
ˉ
t
α
t
ϵ
t
+
1
−
α
t
1
−
α
ˉ
t
α
t
ϵ
θ
^
(
x
t
,
t
)
∣
∣
2
2
]
=
1
2
σ
q
2
(
t
)
[
∣
∣
−
1
−
α
t
1
−
α
ˉ
t
α
t
(
ϵ
t
−
ϵ
θ
^
(
x
t
,
t
)
)
∣
∣
2
2
]
=
1
2
σ
q
2
(
t
)
(
1
−
α
t
)
2
(
1
−
α
ˉ
t
)
α
t
[
∣
∣
ϵ
t
−
ϵ
θ
^
(
x
t
,
t
)
∣
∣
2
2
]
\begin{align*} & D_{KL}(q(x_{t-1}|x_t,x_0)||p_{\theta}(x_{t-1}|x_t)) \\ & = \frac{1}{2\sigma_q^2(t)}[||\mu_q-\mu_{\theta}||^2_2] \\ & = \frac{1}{2\sigma_q^2(t)}[||\frac{1}{\sqrt{\alpha_t}}x_t-\frac{1-\alpha_t}{\sqrt{1- \bar{\alpha}_t}\sqrt{\alpha_t}}\epsilon_t-\frac{1}{\sqrt{\alpha_t}}x_t+\frac{1-\alpha_t}{\sqrt{1- \bar{\alpha}_t}\sqrt{\alpha_t}}\hat{\epsilon_{\theta}}(x_t,t)||^2_2] \\ & = \frac{1}{2\sigma_q^2(t)}[||-\frac{1-\alpha_t}{\sqrt{1- \bar{\alpha}_t}\sqrt{\alpha_t}}\epsilon_t+\frac{1-\alpha_t}{\sqrt{1- \bar{\alpha}_t}\sqrt{\alpha_t}}\hat{\epsilon_{\theta}}(x_t,t)||^2_2] \\ & = \frac{1}{2\sigma_q^2(t)}[||-\frac{1-\alpha_t}{\sqrt{1- \bar{\alpha}_t}\sqrt{\alpha_t}}(\epsilon_t-\hat{\epsilon_{\theta}}(x_t,t))||^2_2] \\ & = \frac{1}{2\sigma_q^2(t)}\frac{(1-\alpha_t)^2}{(1- \bar{\alpha}_t)\alpha_t}[||\epsilon_t-\hat{\epsilon_{\theta}}(x_t,t)||^2_2] \tag{11} \end{align*}
DKL(q(xt−1∣xt,x0)∣∣pθ(xt−1∣xt))=2σq2(t)1[∣∣μq−μθ∣∣22]=2σq2(t)1[∣∣αt1xt−1−αˉtαt1−αtϵt−αt1xt+1−αˉtαt1−αtϵθ^(xt,t)∣∣22]=2σq2(t)1[∣∣−1−αˉtαt1−αtϵt+1−αˉtαt1−αtϵθ^(xt,t)∣∣22]=2σq2(t)1[∣∣−1−αˉtαt1−αt(ϵt−ϵθ^(xt,t))∣∣22]=2σq2(t)1(1−αˉt)αt(1−αt)2[∣∣ϵt−ϵθ^(xt,t)∣∣22](11)
至此,我们完成了损失函数构建的所有推导,DDPM训练时就是将公式(11)作为损失函数,其中
ϵ
t
\epsilon_t
ϵt是前向过程中从第
t
−
1
t-1
t−1步到第
t
t
t步时向样本数据中添加的随机噪声,而
ϵ
θ
^
(
x
t
,
t
)
\hat{\epsilon_{\theta}}(x_t,t)
ϵθ^(xt,t)则是模型预测的噪声,将分布预测任务转为了噪声预测任务,即DDPM最终的训练本质是以前向过程加噪的真实噪声为目标对象,通过模型预测前向过程加噪的噪声量,将其与真实加噪比较、计算损失,实现训练闭环。
训练与采样过程
训练过程
- 从高斯分布 N ∼ ( 0 , I ) N \sim (0,I) N∼(0,I)中随机采样一个噪声 ϵ t \epsilon_t ϵt,从{1,2,…,T}中随机采样一个时间步 t t t;注意每个时刻都是独立重新采样新的随机高斯噪声,不同时刻 t , ϵ t,\epsilon t,ϵ是不一样的值
- 将随机采样的 ϵ t , t \epsilon_t,t ϵt,t带入公式 x t = α ˉ t x 0 + ( 1 − α ˉ t ) ϵ t x_t = \sqrt{ \bar{\alpha}_t}x_0+\sqrt{(1- \bar{\alpha}_t)}\epsilon_t xt=αˉtx0+(1−αˉt)ϵt可以基于 x 0 x_0 x0计算出 x t x_t xt; x 0 x_0 x0是已知的训练数据
- 将 x t , t x_t,t xt,t输入神经网络模型,模型预测值 ϵ θ ^ ( x t , t ) \hat{\epsilon_{\theta}}(x_t,t) ϵθ^(xt,t)即模型预测的噪声
- 计算 ϵ \epsilon ϵ和 ϵ θ ^ ( x t , t ) \hat{\epsilon_{\theta}}(x_t,t) ϵθ^(xt,t)之间的平方差误差,通过梯度计算、反向传播优化模型参数,最小化此平方误差
基于上述步骤对模型进行训练,训练结束之后,整个模型就相当于是
ϵ
θ
^
(
x
t
,
t
)
\hat{\epsilon_{\theta}}(x_t,t)
ϵθ^(xt,t),即模型能够预测前向过程中从
x
t
−
1
x_{t-1}
xt−1到
x
t
x_t
xt添加的噪声。从直觉上讲,模型预测每一步的噪声比直接预测
x
0
x_0
x0要容易很多,DDPM论文就是证明了此想法。以下是训练伪代码:
采样过程
- 从高斯分布 N ∼ ( 0 , I ) N \sim (0,I) N∼(0,I)中随机采样一个纯噪声 x t x_t xt; x t x_t xt为采样过程起点
- 从时间步T开始,逐渐按顺序循环执行以下步骤到时间步1;与训练过程不同,采样过程中时间步是按逆向过程一步一步执行的,训练过程中时间步是随机采样
2.1. 从高斯分布 N ∼ ( 0 , I ) N \sim (0,I) N∼(0,I)中随机采样一个噪声 z z z
2.2. 使用训练后的模型预测出前向过程中从 x t − 1 x_{t-1} xt−1到 x t x_t xt添加的噪声: ϵ ~ = ϵ θ ^ ( x t , t ) \tilde{\epsilon}=\hat{\epsilon_{\theta}}(x_t,t) ϵ~=ϵθ^(xt,t)
2.3. 使用公式(8)计算 x 0 x_0 x0的估计值 x 0 ~ = x t − 1 − α ˉ t ϵ ~ α ˉ t \tilde{x_0}=\frac{x_t-\sqrt{1- \bar{\alpha}_t}\tilde{\epsilon}}{\sqrt{ \bar{\alpha}_t}} x0~=αˉtxt−1−αˉtϵ~
2.4. 使用公式 μ q ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) x t + α ˉ t − 1 ( 1 − α t ) x 0 1 − α ˉ t \mu_q(x_t,x_0)=\frac{\sqrt{\alpha_t}(1- \bar{\alpha}_{t-1})x_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)x_0}{1- \bar{\alpha}_t} μq(xt,x0)=1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)x0计算目标分布 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt−1∣xt,x0)的均值估计值 μ ~ = μ t ( x t , x 0 ~ ) \tilde{\mu}=\mu_t({x_t,\tilde{x_0}}) μ~=μt(xt,x0~)
2.5. 至此,当前时间步目标分布 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt−1∣xt,x0)的均值和方差均已知,可直接采样出 x t − 1 x_{t-1} xt−1,即 x t − 1 = μ ~ + σ t z x_{t-1}=\tilde{\mu}+\sigma_tz xt−1=μ~+σtz - 上述第二步骤循环执行到时间步1时结束执行,获得最终的预测值 x 0 x_0 x0
以下是采样伪代码:
总结
本文通过大量公式推导,完整熟悉了DDPM的数学原理,其中涉及到的数学知识基本都是高中数学,详细读者略微用心就能完全搞清楚整个扩散模型原理。有的读者可能会觉得,只要会用不就好了,为什么要花这么多时间了解原理细节;但本文中的大量原理是与我们在使用stable diffusion生图时息息相关的。如为什么会stable diffusion webui中有那么多采样方法,为什么comfyui中有采样器和调度器,这些就是与扩散模型中的采样方式、加噪方式相关;为什么comfyui文生图时需要先初始化一个空的Latent Image,这是因为文生图是就是从噪声开始的。相信本文能提高读者对扩散模型生成原理的理解程度,进而提高对stable diffusion模型的理解。