Preface
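In recent years, generative models led by generative adversarial networks (GANs) and variational autoencoders (VAEs) have been a research hotspot. Each has its own strengths, and each has weaknesses that follow from its underlying principles. Anyone who has worked with both will usually see them this way, taking image generation as an example. A GAN generates more realistic images, but it is prone to mode collapse: most generated images fall into a few modes, and the model fails to cover the full data distribution. GANs can also be hard to train, and a good GAN typically needs tedious hyperparameter tuning, regularization, and so on. A VAE, by contrast, is trained to approximate the log-likelihood of the data distribution. It generates more diverse images than a GAN, but its image quality still lags behind.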
Is there a model that gets the best of both?
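Diffusion models seem to make this possible. The first version was proposed in 2015, but diffusion models only began to attract attention after a 2020 paper improved them substantially. Research on them has grown quickly over the past two years. In particular, OpenAI (the company behind ChatGPT) made several landmark contributions to improving diffusion models, which gave the approach something of a brand effect.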
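This post covers the two original diffusion-model papers. It includes the derivations of the formulas and some of the physical intuition behind them, and it ends with code for the model that is as simple and complete as possible.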
Papers:
"Deep Unsupervised Learning using Nonequilibrium Thermodynamics"
"Denoising Diffusion Probabilistic Models"
Basic idea
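Take a dataset with data distribution $p(\mathbf{x})$. Suppose we keep adding Gaussian noise to it as a Markov process, slowly destroying the structure of the data distribution. If the process is slow enough, the final distribution can be regarded as an isotropic Gaussian (identical in every dimension). This process is called diffusion. Now suppose we are given an image sampled entirely from a Gaussian distribution. Can we use the known parameters of the diffusion process to work backward and recover the original data distribution $p(\mathbf{x})$?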
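For convenience, we call the process that diffuses the original data distribution into pure Gaussian noise the forward diffusion process. The process that goes from Gaussian noise back to the original image distribution is called the reverse diffusion process.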
Notation
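Denote the original image by $\mathbf{x}_0$ and the image at time $t$ of the diffusion process by $\mathbf{x}_t$. The total time needed to turn the original image into pure Gaussian noise is $T$, corresponding to $\mathbf{x}_T$. Different papers set $T$ differently. The figure below gives an illustrative example: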
Define $q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$ as the transition distribution of the forward process, so $\mathbf{x}_{t-1}$ always denotes the less noisy image. Likewise, define $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ as the transition distribution of the reverse process. The forward transition has the following specific form:
$$q(\mathbf{x}_t \mid \mathbf{x}_{t-1})=\mathcal{N}\left(\mathbf{x}_t; \sqrt{1-\beta_t}\,\mathbf{x}_{t-1}, \beta_t \mathbf{I}\right)$$
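Read this formula as follows. The distribution of $\mathbf{x}_t$ is conditioned on $\mathbf{x}_{t-1}$. It is Gaussian, with mean $\sqrt{1-\beta_t}\,\mathbf{x}_{t-1}$ and covariance $\beta_t \mathbf{I}$. Recall that in a Markov process, knowing the probability distribution of one state and the transition distribution lets us compute the distribution of the next state. Here the images $\mathbf{x}$ at different times are the different states.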
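Below is a minimal sketch of one forward transition. It assumes an image is flattened into a plain Python list of floats, and the function name `forward_step` is hypothetical (it does not come from either paper). The previous image is scaled by $\sqrt{1-\beta_t}$, and every dimension receives its own independent Gaussian noise with variance $\beta_t$. That per-dimension noise is what the covariance $\beta_t\mathbf{I}$ means.

```python
import math
import random


def forward_step(x_prev, beta_t, rng=random):
    """Draw x_t ~ q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) * x_{t-1}, beta_t * I).

    x_prev is a flat list of floats. Because the covariance is beta_t * I,
    every dimension receives its own independent N(0, 1) noise sample.
    """
    mean_scale = math.sqrt(1.0 - beta_t)
    noise_std = math.sqrt(beta_t)
    return [mean_scale * x + noise_std * rng.gauss(0.0, 1.0) for x in x_prev]


# Usage: push a toy 4-pixel "image" through one diffusion step.
rng = random.Random(0)
x0 = [0.5, -0.2, 1.0, 0.0]
x1 = forward_step(x0, beta_t=0.02, rng=rng)
print(x1)
```

Setting `beta_t=0` returns the input unchanged. Repeated calls with a small `beta_t` shrink the signal in expectation while the noise accumulates.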
With this formula alone, obtaining the final $\mathbf{x}_T$ requires applying it iteratively $T$ times. There is a more direct way to obtain the image at any time $t$:
First, we introduce the following definitions:
$$\begin{aligned} \alpha_t &= 1-\beta_t \\ \bar{\alpha}_t &= \prod_{s=1}^t \alpha_s \end{aligned}$$
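To make the definitions concrete, the sketch below builds $\beta_t$, $\alpha_t$ and $\bar\alpha_t$ for a linear schedule. The values $T=1000$, $\beta_1=10^{-4}$ and $\beta_T=0.02$ are the linear schedule used in the DDPM paper; here they are an assumption, not something this section fixes. The helper names are hypothetical. Python lists are 0-indexed, so index `t - 1` holds time step $t$.

```python
from itertools import accumulate
from operator import mul


def linear_beta_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """beta_1..beta_T spaced linearly; list index t-1 holds beta_t."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]


def alphas_and_alpha_bars(betas):
    """alpha_t = 1 - beta_t and alpha_bar_t = prod_{s=1}^{t} alpha_s."""
    alphas = [1.0 - b for b in betas]
    alpha_bars = list(accumulate(alphas, mul))  # running product of the alphas
    return alphas, alpha_bars


betas = linear_beta_schedule()
alphas, alpha_bars = alphas_and_alpha_bars(betas)
print(alpha_bars[0], alpha_bars[499], alpha_bars[-1])
```

With these settings, $\bar\alpha_t$ falls from $0.9999$ at $t=1$ to roughly $4\times10^{-5}$ at $t=T$, so $\mathbf{x}_T$ retains almost no trace of $\mathbf{x}_0$.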
Next, rewrite the transition using this notation. We know the mean and variance of the Gaussian that $\mathbf{x}_t$ follows, so the reparameterization trick lets us write $\mathbf{x}_t$ as:
$$\begin{aligned} \mathbf{x}_t &=\sqrt{1-\beta_t}\, \mathbf{x}_{t-1}+\sqrt{\beta_t}\,\boldsymbol{\epsilon} \\ &=\sqrt{\alpha_t}\, \mathbf{x}_{t-1}+\sqrt{1-\alpha_t}\, \boldsymbol{\epsilon} \end{aligned}$$
If the reparameterization trick is new to you, it is worth looking up. It is easy to understand, so we only state it briefly here:
with $\boldsymbol{\epsilon}\sim\mathcal{N}(\mathbf{0}, \mathbf{I})$, a sample from $\mathcal{N}(\mu, \sigma^2)$ can be written as $\mu + \sigma\boldsymbol{\epsilon}$.
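Here is a quick numerical check of the reparameterization. It is a sketch, and the helper name `reparam_sample` is hypothetical. All of the randomness comes from the standard normal $\epsilon$. Shifting and scaling it by $\mu$ and $\sigma$ produces samples whose empirical mean and variance approach $\mu$ and $\sigma^2$.

```python
import random
import statistics


def reparam_sample(mu, sigma, rng):
    """Sample N(mu, sigma^2) as mu + sigma * eps, with eps ~ N(0, 1)."""
    eps = rng.gauss(0.0, 1.0)
    return mu + sigma * eps


rng = random.Random(42)
samples = [reparam_sample(3.0, 0.5, rng) for _ in range(50_000)]
mean = statistics.fmean(samples)
var = statistics.pvariance(samples)
print(mean, var)  # approximately 3.0 and 0.25
```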
Applying the same recursion one step earlier gives the relation between $\mathbf{x}_{t-1}$ and $\mathbf{x}_{t-2}$:
$$\mathbf{x}_{t-1}=\sqrt{\alpha_{t-1}}\, \mathbf{x}_{t-2}+\sqrt{1-\alpha_{t-1}}\, \boldsymbol\epsilon$$
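To avoid confusion, we temporarily write $\boldsymbol\epsilon^*$ for a sample from $\mathcal{N}(\mathbf{0}, \mathbf{I})$. Each $\boldsymbol\epsilon^*$ below is an independent draw, not the same sample, because fresh noise is added at every step. Substituting gives: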
$$\begin{aligned} \mathbf{x}_t &=\sqrt{\alpha_t}\, \mathbf{x}_{t-1}+\sqrt{1-\alpha_t}\, \boldsymbol\epsilon^* \\ &=\sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\, \mathbf{x}_{t-2}+\sqrt{1-\alpha_{t-1}}\, \boldsymbol\epsilon^*\right)+\sqrt{1-\alpha_t}\, \boldsymbol\epsilon^* \\ &=\sqrt{\alpha_t}\sqrt{\alpha_{t-1}}\, \mathbf{x}_{t-2} + \sqrt{\alpha_t}\sqrt{1-\alpha_{t-1}}\, \boldsymbol\epsilon^* + \sqrt{1-\alpha_t}\, \boldsymbol\epsilon^* \end{aligned}$$
We now merge the last two terms, which form a weighted sum of two independent standard normal variables:
$$\sqrt{\alpha_t}\sqrt{1-\alpha_{t-1}}\, \boldsymbol\epsilon^* + \sqrt{1-\alpha_t}\, \boldsymbol\epsilon^*$$
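Recall that if $C \sim \mathcal{N}(\mu_C, \sigma_C^2)$ and $D \sim \mathcal{N}(\mu_D, \sigma_D^2)$ are independent (applied per dimension), then $A = C + D$ is also Gaussian, with $\sigma_A^2=\sigma_C^2+\sigma_D^2$ and $\mu_A=\mu_C + \mu_D$. Here both terms have zero mean, and a scaled term $k\boldsymbol\epsilon^*$ has variance $k^2$. The variance of the expression above is therefore: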
$$\begin{aligned} \sigma^2&=\left(\sqrt{\alpha_t}\sqrt{1-\alpha_{t-1}}\right)^2+\left(\sqrt{1-\alpha_t}\right)^2 \\ &= \alpha_t (1-\alpha_{t-1}) + 1-\alpha_t \\ & = 1 - \alpha_t \alpha_{t-1} \end{aligned}$$
So the sum can be written as a single new Gaussian term:
$$\begin{aligned} & \sqrt{\alpha_t}\sqrt{1-\alpha_{t-1}}\, \boldsymbol\epsilon^* + \sqrt{1-\alpha_t}\, \boldsymbol\epsilon^* \\ &=\sqrt{1 - \alpha_t \alpha_{t-1}}\, \boldsymbol\epsilon^* \end{aligned}$$
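The merging step can be checked by Monte Carlo. The sketch below draws the two noise terms independently and compares the empirical variance with $1-\alpha_t\alpha_{t-1}$. The helper name is hypothetical, and the values of $\alpha_t$ and $\alpha_{t-1}$ are arbitrary illustrative choices.

```python
import math
import random
import statistics


def merged_noise_variance(alpha_t, alpha_prev, n=100_000, seed=0):
    """Empirical variance of sqrt(a_t) * sqrt(1 - a_{t-1}) * eps1 + sqrt(1 - a_t) * eps2,
    where eps1 and eps2 are independent N(0, 1) draws."""
    rng = random.Random(seed)
    c1 = math.sqrt(alpha_t) * math.sqrt(1.0 - alpha_prev)
    c2 = math.sqrt(1.0 - alpha_t)
    values = [c1 * rng.gauss(0.0, 1.0) + c2 * rng.gauss(0.0, 1.0) for _ in range(n)]
    return statistics.pvariance(values)


alpha_t, alpha_prev = 0.9, 0.8             # illustrative values, not from the paper
empirical = merged_noise_variance(alpha_t, alpha_prev)
predicted = 1.0 - alpha_t * alpha_prev     # sigma^2 = 1 - alpha_t * alpha_{t-1}
print(empirical, predicted)
```

Independence matters here. If the same $\epsilon$ were reused in both terms, the variance would be $\left(\sqrt{\alpha_t}\sqrt{1-\alpha_{t-1}}+\sqrt{1-\alpha_t}\right)^2$ instead of the sum of the squared coefficients.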
Switching $\boldsymbol\epsilon^*$ back to $\boldsymbol\epsilon$ and substituting into $\mathbf{x}_t$ gives:
$$\begin{aligned} \mathbf{x}_t &=\sqrt{\alpha_t}\, \mathbf{x}_{t-1}+\sqrt{1-\alpha_t}\, \boldsymbol\epsilon^* \\ &= \sqrt{\alpha_t}\sqrt{\alpha_{t-1}}\, \mathbf{x}_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}}\, \boldsymbol\epsilon \end{aligned}$$
Continue the recursion, expressing $\mathbf{x}_{t-2}$ in terms of $\mathbf{x}_{t-3}$ and so on down to $\mathbf{x}_0$. This yields:
$$\begin{aligned} \mathbf{x}_t & = \sqrt{\alpha_t}\sqrt{\alpha_{t-1}}\, \mathbf{x}_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}}\, \boldsymbol\epsilon \\ & =\sqrt{\alpha_t \alpha_{t-1} \alpha_{t-2}}\, \mathbf{x}_{t-3}+\sqrt{1-\alpha_t \alpha_{t-1} \alpha_{t-2}}\, \boldsymbol\epsilon \\ & =\sqrt{\alpha_t \alpha_{t-1} \cdots \alpha_1}\, \mathbf{x}_0+\sqrt{1-\alpha_t \alpha_{t-1} \cdots \alpha_1}\, \boldsymbol\epsilon \\ & =\sqrt{\bar{\alpha}_t}\, \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\, \boldsymbol\epsilon \end{aligned}$$
This gives a formula that expresses the image $\mathbf{x}_t$ at time $t$ directly in terms of $\mathbf{x}_0$:
$$q(\mathbf{x}_t \mid \mathbf{x}_{0})=\mathcal{N}\left(\mathbf{x}_t;\sqrt{\bar{\alpha}_t}\, \mathbf{x}_0, (1-\bar{\alpha}_t) \mathbf{I}\right)$$
This formula is important because it lets us compute the image $\mathbf{x}_t$ at time $t$ in a single step. There is no need to iterate $q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$ step by step, which saves a lot of computation when $T$ is large.
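The sketch below compares the two routes for a single scalar pixel. One route applies $q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$ $t$ times; the other samples $q(\mathbf{x}_t \mid \mathbf{x}_0)$ once. Up to Monte Carlo error, both should give mean $\sqrt{\bar\alpha_t}\,\mathbf{x}_0$ and variance $1-\bar\alpha_t$. The schedule ($T=200$, linear $\beta$ from $10^{-4}$ to $0.02$), $t=50$ and $\mathbf{x}_0=1.5$ are arbitrary illustrative choices, and the function names are hypothetical.

```python
import math
import random
import statistics
from itertools import accumulate
from operator import mul


def linear_betas(T, beta_start=1e-4, beta_end=0.02):
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]


def iterate_forward(x0, betas, t, rng):
    """x_t by applying q(x_s | x_{s-1}) for s = 1..t, one step at a time."""
    x = x0
    for beta in betas[:t]:
        x = math.sqrt(1.0 - beta) * x + math.sqrt(beta) * rng.gauss(0.0, 1.0)
    return x


def sample_closed_form(x0, alpha_bars, t, rng):
    """x_t ~ q(x_t | x_0) = N(sqrt(abar_t) * x0, 1 - abar_t) in one shot."""
    abar = alpha_bars[t - 1]
    return math.sqrt(abar) * x0 + math.sqrt(1.0 - abar) * rng.gauss(0.0, 1.0)


T, t, x0, n = 200, 50, 1.5, 5000
betas = linear_betas(T)
alpha_bars = list(accumulate((1.0 - b for b in betas), mul))
rng = random.Random(0)
iterative = [iterate_forward(x0, betas, t, rng) for _ in range(n)]
one_shot = [sample_closed_form(x0, alpha_bars, t, rng) for _ in range(n)]

target_mean = math.sqrt(alpha_bars[t - 1]) * x0
target_var = 1.0 - alpha_bars[t - 1]
for name, xs in (("iterative", iterative), ("one-shot", one_shot)):
    print(name, statistics.fmean(xs), statistics.pvariance(xs))
print("target", target_mean, target_var)
```

The one-shot version needs a single noise draw per sample instead of $t$ draws, which is the saving described above.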
Our next goal is to recover $\mathbf{x}_0$. In mathematical terms, we need a parameterized reverse transition $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ that maximizes the probability $p_\theta(\mathbf{x}_0)$. In log space and written as a loss, we look for the parameters $\theta$ that minimize $-\log p_\theta(\mathbf{x}_0)$. More concretely, there are infinitely many possible values of $\theta$. Each one yields a different $p_\theta$ when $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ is applied recursively, and we want the $\theta$ whose $p_\theta$ assigns the highest probability to the data $\mathbf{x}_0$.
Loss function
Using the ELBO (evidence lower bound, also called the variational lower bound), $-\log p_\theta(\mathbf{x}_0)$ can be bounded from above as follows:
$$-\log p_\theta(\mathbf{x}_0) \leq -\log p_\theta(\mathbf{x}_0)+D_{KL}\left(q(\mathbf{x}_{1:T} \mid \mathbf{x}_0) \,\|\, p_\theta(\mathbf{x}_{1:T} \mid \mathbf{x}_0)\right)$$
Here $D_{KL}(\cdot)$ is the KL divergence, which is always greater than or equal to zero. Minimizing the right-hand side, an upper bound, therefore also pushes down $-\log p_\theta(\mathbf{x}_0)$. However, the right-hand side still contains $-\log p_\theta(\mathbf{x}_0)$, which cannot be computed directly, so we use the KL term to cancel it. To do that, we expand the KL term.
A friendly warning: the derivation that follows is tedious (math heavy).
$$\begin{aligned} &D_{KL}\left(q(\mathbf{x}_{1:T} \mid \mathbf{x}_0) \,\|\, p_\theta(\mathbf{x}_{1:T} \mid \mathbf{x}_0)\right) \\ &= \int q(\mathbf{x}_{1:T} \mid \mathbf{x}_0) \log\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}\, d\mathbf{x}_{1:T} \\ &= \mathbb{E}_q\left[ \log\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{1:T} \mid \mathbf{x}_0)} \right] \end{aligned}$$
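Here $\mathbb{E}_q$ denotes the expectation over $\mathbf{x}_{1:T} \sim q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)$. By Bayes' rule, the denominator inside the brackets can be expanded further: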
$$\begin{aligned} p_\theta(\mathbf{x}_{1:T} \mid \mathbf{x}_0) &=\frac{p_\theta(\mathbf{x}_0 \mid \mathbf{x}_{1:T})\, p_\theta(\mathbf{x}_{1:T})}{p_\theta(\mathbf{x}_0)} \\ &= \frac{p_\theta(\mathbf{x}_0, \mathbf{x}_{1:T})}{p_\theta(\mathbf{x}_0)} \\ & = \frac{p_\theta(\mathbf{x}_{0:T})}{p_\theta(\mathbf{x}_0)} \end{aligned}$$
We can therefore rewrite $\log\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}$ as:
$$\begin{aligned} \log\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{1:T} \mid \mathbf{x}_0)} &= \log \left(\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{\frac{p_\theta(\mathbf{x}_{0:T})}{p_\theta(\mathbf{x}_0)}}\right) \\ &= \log \left(\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}\right)+\log p_\theta(\mathbf{x}_0) \end{aligned}$$
Next, we simplify the expression. The outer expectation $\mathbb{E}_q[\cdot]$ is taken over $\mathbf{x}_{1:T}$; we drop it for now and work with the quantity inside it. Because $\log p_\theta(\mathbf{x}_0)$ does not depend on $\mathbf{x}_{1:T}$, we have $\mathbb{E}_q[\log p_\theta(\mathbf{x}_0)]=\log p_\theta(\mathbf{x}_0)$, so that term passes through the expectation unchanged. The original variational bound can therefore be rewritten as:
$$\begin{aligned} & -\log p_\theta(\mathbf{x}_0)+\log \left(\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}\right)+\log p_\theta(\mathbf{x}_0) \\ =& \log \left(\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}\right) \end{aligned}$$
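Note that $q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)$ is fully known, since it is the fixed forward process. By the Markov property, $q(\mathbf{x}_{1:T} \mid \mathbf{x}_0) = \prod_{t=1}^T q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$ and $p_\theta(\mathbf{x}_{0:T}) = p(\mathbf{x}_T)\prod_{t=1}^T p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$. With these factorizations we can expand the expression further: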
$$\begin{aligned} \log \left(\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}\right) & =\log \left(\frac{\prod_{t=1}^T q(\mathbf{x}_t \mid \mathbf{x}_{t-1})}{p(\mathbf{x}_T) \prod_{t=1}^T p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}\right) \\ & =-\log p(\mathbf{x}_T)+\log \left(\frac{\prod_{t=1}^T q(\mathbf{x}_t \mid \mathbf{x}_{t-1})}{\prod_{t=1}^T p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}\right) \\ & =-\log p(\mathbf{x}_T)+\sum_{t=1}^T \log \left(\frac{q(\mathbf{x}_t \mid \mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}\right) \\ & =-\log p(\mathbf{x}_T)+\sum_{t=2}^T \log \frac{q(\mathbf{x}_t \mid \mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}+\log \frac{q(\mathbf{x}_1 \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)} \end{aligned}$$
Note that $q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$ is known. In principle, Bayes' rule relates it to the reverse conditional $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$. However, that would require the marginals $q(\mathbf{x}_{t-1})$ and $q(\mathbf{x}_t)$, which depend on the whole data distribution and are intractable. The authors therefore do not try to express $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ directly. Instead, they look for the form of $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$. The intuition behind this choice is as follows.
Given an image at time $T$ that is pure Gaussian noise, predicting the image at the previous time step is very hard:
But if we know both $\mathbf{x}_t$ and $\mathbf{x}_0$, predicting the distribution at the previous time step becomes easy:
By the Markov property, the state at a given time depends only on the state at the previous time and not on any earlier state. We can therefore rewrite $q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$ as:
$$q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) =q(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}, \mathbf{x}_0)$$
Now let us see which transformation turns the target distribution $q(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_0)$ into the expression above:
$$\begin{aligned} q(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}, \mathbf{x}_0) &= \frac{q(\mathbf{x}_{t-1}, \mathbf{x}_{t}, \mathbf{x}_0)}{q(\mathbf{x}_{t}, \mathbf{x}_0)} \\ &= \frac{q(\mathbf{x}_{t} \mid \mathbf{x}_{t-1}, \mathbf{x}_0) \cdot q(\mathbf{x}_{t-1}, \mathbf{x}_0)}{q(\mathbf{x}_{t}, \mathbf{x}_0)} \\ & = \frac{q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) \cdot q(\mathbf{x}_{t-1}, \mathbf{x}_0)}{q(\mathbf{x}_{t}, \mathbf{x}_0)} \\ &= \frac{q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) \cdot q(\mathbf{x}_{t-1}\mid \mathbf{x}_0) \cdot q(\mathbf{x}_0)}{q(\mathbf{x}_{t}\mid \mathbf{x}_0) \cdot q(\mathbf{x}_0)} \\ & = \frac{q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) \cdot q(\mathbf{x}_{t-1}\mid \mathbf{x}_0)}{q(\mathbf{x}_{t}\mid \mathbf{x}_0)} \end{aligned}$$
Rearranging, we can express $q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$ in terms of $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$:
$$q(\mathbf{x}_t \mid \mathbf{x}_{t-1})=\frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)\, q(\mathbf{x}_t \mid \mathbf{x}_0)}{q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}$$
Substituting this result into $\log \left(\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}\right)$, for the terms with $t \ge 2$, gives:
$$\begin{aligned} \log \left(\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}\right) & =-\log p(\mathbf{x}_T)+\sum_{t=2}^T \log \frac{q(\mathbf{x}_t \mid \mathbf{x}_{t-1})}{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}+\log \frac{q(\mathbf{x}_1 \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)} \\ &= -\log p(\mathbf{x}_T)+\sum_{t=2}^T \log \frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)\, q(\mathbf{x}_t \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)\, q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)} +\log \frac{q(\mathbf{x}_1 \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)} \\ &= -\log p(\mathbf{x}_T)+\sum_{t=2}^T \log \left(\frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}\right)+\sum_{t=2}^T \log \left(\frac{q(\mathbf{x}_t \mid \mathbf{x}_0)}{q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}\right)+\log \left(\frac{q(\mathbf{x}_1 \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)}\right) \end{aligned}$$
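Let us briefly look at why the last ($t=1$) term is kept separate. Applying the same substitution to its numerator $q(\mathbf{x}_1 \mid \mathbf{x}_0)$ would give: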
$$\frac{q(\mathbf{x}_1 \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)} = \frac{q(\mathbf{x}_0 \mid \mathbf{x}_1, \mathbf{x}_0)\, q(\mathbf{x}_1 \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)\, q(\mathbf{x}_0 \mid \mathbf{x}_0)}$$
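Here $q(\mathbf{x}_0 \mid \mathbf{x}_1, \mathbf{x}_0)$ and $q(\mathbf{x}_0 \mid \mathbf{x}_0)$ are degenerate, because conditioning on $\mathbf{x}_0$ makes $\mathbf{x}_0$ deterministic. They carry no useful information, which is why this term is handled separately. Next, we focus on the two middle terms.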
Expanding the third term, with $T=4$ as an example:
$$\begin{aligned} \sum_{t=2}^4 \log \frac{q(\mathbf{x}_t \mid \mathbf{x}_0)}{q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)} &= \log \prod_{t=2}^4 \frac{q(\mathbf{x}_t \mid \mathbf{x}_0)}{q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)} \\ & =\log \frac{q(\mathbf{x}_2 \mid \mathbf{x}_0)\, q(\mathbf{x}_3 \mid \mathbf{x}_0)\, q(\mathbf{x}_4 \mid \mathbf{x}_0)}{q(\mathbf{x}_1 \mid \mathbf{x}_0)\, q(\mathbf{x}_2 \mid \mathbf{x}_0)\, q(\mathbf{x}_3 \mid \mathbf{x}_0)} \\ & = \log \frac{q(\mathbf{x}_4 \mid \mathbf{x}_0)}{q(\mathbf{x}_1 \mid \mathbf{x}_0)} \end{aligned}$$
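The cancellation is purely algebraic, so it holds for any trajectory. The sketch below evaluates both sides for $T=4$ using the Gaussian log-density of $q(\mathbf{x}_t \mid \mathbf{x}_0)$ in 1-D. The schedule $\bar\alpha_1,\dots,\bar\alpha_4$ is an arbitrary illustrative choice, and the helper names are hypothetical.

```python
import math
import random


def log_normal_pdf(x, mean, var):
    """log N(x; mean, var) for a scalar x."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def log_q_xt_given_x0(xt, x0, abar_t):
    """log q(x_t | x_0) with q(x_t | x_0) = N(sqrt(abar_t) * x0, 1 - abar_t)."""
    return log_normal_pdf(xt, math.sqrt(abar_t) * x0, 1.0 - abar_t)


alpha_bars = [0.9, 0.7, 0.5, 0.3]   # abar_1..abar_4, illustrative values only
x0 = 0.8
rng = random.Random(1)
xs = [rng.gauss(0.0, 1.0) for _ in range(4)]   # x_1..x_4 (any values work)

# Left side: sum_{t=2}^{4} log q(x_t | x_0) / q(x_{t-1} | x_0); list index t-1 holds step t.
lhs = sum(
    log_q_xt_given_x0(xs[t - 1], x0, alpha_bars[t - 1])
    - log_q_xt_given_x0(xs[t - 2], x0, alpha_bars[t - 2])
    for t in range(2, 5)
)
# Right side after cancellation: log q(x_4 | x_0) / q(x_1 | x_0).
rhs = log_q_xt_given_x0(xs[3], x0, alpha_bars[3]) - log_q_xt_given_x0(xs[0], x0, alpha_bars[0])
print(lhs, rhs)
```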
So this term simplifies: for general $T$ it telescopes to $\log\frac{q(\mathbf{x}_T \mid \mathbf{x}_0)}{q(\mathbf{x}_1 \mid \mathbf{x}_0)}$. We can now rewrite $\log \left(\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}\right)$ once more:
$$\begin{aligned} \log \left(\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}\right) & = -\log p(\mathbf{x}_T)+\sum_{t=2}^T \log \left(\frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}\right)+\sum_{t=2}^T \log \left(\frac{q(\mathbf{x}_t \mid \mathbf{x}_0)}{q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}\right)+\log \left(\frac{q(\mathbf{x}_1 \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)}\right) \\ &= -\log p(\mathbf{x}_T)+\sum_{t=2}^T \log \left(\frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}\right)+\log \left(\frac{q(\mathbf{x}_T \mid \mathbf{x}_0)}{q(\mathbf{x}_1 \mid \mathbf{x}_0)}\right)+\log \left(\frac{q(\mathbf{x}_1 \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)}\right) \\ &= -\log p(\mathbf{x}_T)+\sum_{t=2}^T \log \left(\frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}\right) + \log q(\mathbf{x}_T \mid \mathbf{x}_0)-\log p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1) \\ &=\log \left(\frac{q(\mathbf{x}_T \mid \mathbf{x}_0)}{p(\mathbf{x}_T)}\right)+\sum_{t=2}^T \log \left(\frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}\right)-\log p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1) \end{aligned}$$
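When the outer expectation $\mathbb{E}_q[\cdot]$ dropped earlier is restored, each log-ratio becomes (the expectation of) a KL divergence. The expression is therefore equivalent to: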
$$\begin{aligned} \log \left(\frac{q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}\right) &= \log \left(\frac{q(\mathbf{x}_T \mid \mathbf{x}_0)}{p(\mathbf{x}_T)}\right)+\sum_{t=2}^T \log \left(\frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}\right)-\log p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1) \\ & \equiv D_{KL}\left(q(\mathbf{x}_T \mid \mathbf{x}_0) \,\|\, p(\mathbf{x}_T)\right)+\sum_{t=2}^T D_{KL}\left(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) \,\|\, p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)\right)-\log p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1) \end{aligned}$$
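The first term, $D_{KL}\left(q(\mathbf{x}_T \mid \mathbf{x}_0) \,\|\, p(\mathbf{x}_T)\right)$, contains no trainable parameters: the forward process is fixed and $p(\mathbf{x}_T)$ is a standard Gaussian. It is therefore a constant with respect to $\theta$, and it is almost zero because $\mathbf{x}_T$ is nearly pure noise. We focus on the remaining two terms, which depend on the parameters. First, let us solve for $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$ explicitly; it is a product and quotient of three Gaussian densities: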
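To see why the first term is negligible, the sketch below evaluates it numerically. It uses the closed-form KL divergence between two 1-D Gaussians, $D_{KL}(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)) = \log\frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2}$, with $q(\mathbf{x}_T\mid\mathbf{x}_0)$ against $p(\mathbf{x}_T)=\mathcal{N}(0,1)$ for one dimension. For a diagonal Gaussian, the full KL is the sum over dimensions. The sketch assumes the linear DDPM-style schedule ($T=1000$, $\beta$ from $10^{-4}$ to $0.02$) and pixel values scaled to $[-1,1]$; the helper names are hypothetical.

```python
import math
from itertools import accumulate
from operator import mul


def kl_gaussian_1d(mu1, var1, mu2, var2):
    """D_KL( N(mu1, var1) || N(mu2, var2) ) for scalar Gaussians."""
    return 0.5 * (math.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0)


def prior_kl(x0, alpha_bar_T):
    """D_KL( q(x_T | x_0) || N(0, 1) ) for a single dimension."""
    return kl_gaussian_1d(math.sqrt(alpha_bar_T) * x0, 1.0 - alpha_bar_T, 0.0, 1.0)


T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alpha_bar_T = list(accumulate((1.0 - b for b in betas), mul))[-1]

for x0 in (-1.0, 0.0, 1.0):
    print(x0, prior_kl(x0, alpha_bar_T))
```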
$$\begin{aligned} q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)&=q(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0) \frac{q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}{q(\mathbf{x}_t \mid \mathbf{x}_0)} \\ & \propto \exp \left(-\frac{1}{2}\left(\frac{(\mathbf{x}_t-\sqrt{\alpha_t}\, \mathbf{x}_{t-1})^2}{\beta_t}+\frac{(\mathbf{x}_{t-1}- \sqrt{\bar{\alpha}_{t-1}}\, \mathbf{x}_0)^2}{1-\bar{\alpha}_{t-1}}-\frac{(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\, \mathbf{x}_0)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\frac{\mathbf{x}_t^2-2 \sqrt{\alpha_t}\, \mathbf{x}_t \mathbf{x}_{t-1}+\alpha_t \mathbf{x}_{t-1}^2}{\beta_t}+\frac{\mathbf{x}_{t-1}^2-2 \sqrt{\bar{\alpha}_{t-1}}\, \mathbf{x}_0 \mathbf{x}_{t-1}+\bar{\alpha}_{t-1} \mathbf{x}_0^2}{1-\bar{\alpha}_{t-1}}-\frac{(\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\, \mathbf{x}_0)^2}{1-\bar{\alpha}_t}\right)\right) \\ & =\exp \left(-\frac{1}{2}\left(\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) \mathbf{x}_{t-1}^2-\left(\frac{2 \sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{2 \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0\right) \mathbf{x}_{t-1}+C(\mathbf{x}_t, \mathbf{x}_0)\right)\right) \end{aligned}$$
Here $C(\mathbf{x}_t, \mathbf{x}_0)$ collects the terms that do not depend on $\mathbf{x}_{t-1}$. It is long, so we omit it; expand it yourself if you are interested. What we want is the mean of the Gaussian over $\mathbf{x}_{t-1}$, expressed in terms of $\mathbf{x}_t$ and $\mathbf{x}_0$. So we complete the square and write the exponent in the form $(\mathbf{x}_{t-1}-M)^2$, where $M$ depends only on $\mathbf{x}_t$ and $\mathbf{x}_0$. The result is given directly below.
Let
$$\tilde{\beta}_t=1 \Big/ \left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)=1 \Big/ \left(\frac{\alpha_t-\bar{\alpha}_t+\beta_t}{\beta_t(1-\bar{\alpha}_{t-1})}\right)=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t$$
The second step uses $\alpha_t\bar{\alpha}_{t-1}=\bar{\alpha}_t$, and the last step uses $\alpha_t+\beta_t=1$.
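The small sketch below checks the algebra. It computes $\tilde\beta_t$ both as $1/(\alpha_t/\beta_t + 1/(1-\bar\alpha_{t-1}))$ and in the simplified form, and it confirms that $\tilde\beta_t < \beta_t$, since $(1-\bar\alpha_{t-1})/(1-\bar\alpha_t)<1$. The helper names are hypothetical, and the linear schedule is the same illustrative assumption as before.

```python
from itertools import accumulate
from operator import mul


def beta_tilde_direct(beta_t, alpha_bar_prev):
    """1 / (alpha_t / beta_t + 1 / (1 - alpha_bar_{t-1}))."""
    alpha_t = 1.0 - beta_t
    return 1.0 / (alpha_t / beta_t + 1.0 / (1.0 - alpha_bar_prev))


def beta_tilde_simplified(beta_t, alpha_bar_t, alpha_bar_prev):
    """(1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t."""
    return (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t) * beta_t


T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alpha_bars = list(accumulate((1.0 - b for b in betas), mul))

# Only t >= 2 is needed (the sum in the loss starts at t = 2); index t-1 holds step t.
pairs = [
    (beta_tilde_direct(betas[t - 1], alpha_bars[t - 2]),
     beta_tilde_simplified(betas[t - 1], alpha_bars[t - 1], alpha_bars[t - 2]))
    for t in range(2, T + 1)
]
print(pairs[0], pairs[-1])
```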
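Normalizing the coefficient of $\mathbf{x}_{t-1}^2$ to 1 means dividing every term by the same value. For the coefficient of the linear term this gives the following; the factor 2 is matched by the $-2M\mathbf{x}_{t-1}$ in $(\mathbf{x}_{t-1}-M)^2$: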
$$\begin{aligned} & \left(\frac{\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0\right) \Big/ \left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right) \\ =&\left(\frac{\sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0\right) \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t \\ =& \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1-\bar{\alpha}_t} \mathbf{x}_0 \end{aligned}$$
Indeed, expand $\left[\mathbf{x}_{t-1}-\left(\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} \mathbf{x}_0\right)\right]^2$. It agrees with the expression above after every term is divided by $\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)$, up to terms that do not depend on $\mathbf{x}_{t-1}$. We have now obtained the distribution $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$:
$$\begin{aligned} q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) &=\mathcal{N}\left(\mathbf{x}_{t-1} ; \tilde{\boldsymbol{\mu}}(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t \mathbf{I}\right) \\ \text{where} \quad & \tilde{\boldsymbol{\mu}}(\mathbf{x}_t, \mathbf{x}_0)=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1-\bar{\alpha}_t} \mathbf{x}_0 \\ & \tilde{\sigma}_t^2=\tilde{\beta}_t \end{aligned}$$
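As a sanity check of this result, the sketch below works in 1-D with illustrative values and hypothetical helper names. It compares the log-density of $\mathcal{N}(\tilde\mu, \tilde\beta_t)$ with the Bayes expression $\log q(\mathbf{x}_t\mid\mathbf{x}_{t-1}) + \log q(\mathbf{x}_{t-1}\mid\mathbf{x}_0) - \log q(\mathbf{x}_t\mid\mathbf{x}_0)$ at several values of $\mathbf{x}_{t-1}$. Both sides are normalized densities, so they should agree exactly (up to floating-point error), not just up to a constant. This only holds if the chosen values satisfy $\bar\alpha_t = \alpha_t\bar\alpha_{t-1}$.

```python
import math


def log_normal_pdf(x, mean, var):
    """log N(x; mean, var) for a scalar x."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def posterior_mean_var(xt, x0, beta_t, abar_t, abar_prev):
    """Mean mu_tilde(x_t, x_0) and variance beta_tilde_t of q(x_{t-1} | x_t, x_0)."""
    alpha_t = 1.0 - beta_t
    coef_xt = math.sqrt(alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t)
    coef_x0 = math.sqrt(abar_prev) * beta_t / (1.0 - abar_t)
    var = (1.0 - abar_prev) / (1.0 - abar_t) * beta_t
    return coef_xt * xt + coef_x0 * x0, var


def log_posterior_bayes(x_prev, xt, x0, beta_t, abar_t, abar_prev):
    """log q(x_t | x_{t-1}) + log q(x_{t-1} | x_0) - log q(x_t | x_0)."""
    alpha_t = 1.0 - beta_t
    return (log_normal_pdf(xt, math.sqrt(alpha_t) * x_prev, beta_t)
            + log_normal_pdf(x_prev, math.sqrt(abar_prev) * x0, 1.0 - abar_prev)
            - log_normal_pdf(xt, math.sqrt(abar_t) * x0, 1.0 - abar_t))


beta_t, abar_prev = 0.05, 0.6             # illustrative values
abar_t = abar_prev * (1.0 - beta_t)       # abar_t = alpha_t * abar_{t-1}
xt, x0 = 0.3, -0.7
mu, var = posterior_mean_var(xt, x0, beta_t, abar_t, abar_prev)
gaps = [log_normal_pdf(v, mu, var) - log_posterior_bayes(v, xt, x0, beta_t, abar_t, abar_prev)
        for v in (-2.0, -0.5, 0.0, 0.4, 1.5)]
print(mu, var, gaps)
```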
This may not seem to help much yet, since the loss still contains many terms. Let us look at what the loss has become, starting with just the second term, $\sum_{t=2}^T D_{KL}\left(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) \,\|\, p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)\right)$:
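The blog platform says this post is too long and asked me to split it into two parts. Please continue with the next part: diffusion models 扩散模型公式推导,原理分析与代码(二) (Diffusion models: formula derivation, principle analysis and code, Part 2) 🧐🧐🧐🧐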