文章提出了一个基于随机微分方程的生成模型。之前的Denoising Score Matching with Annealed Langevin Sampling(SMLD)和 Denoising Diffusion Probabilistic Models(DDPM)方法都可以合并到该框架中。
建立一个连续时间索引的扩散过程
{
x
(
t
)
}
t
=
0
T
,
t
∈
[
0
,
T
]
\{\mathbf{x}(t)\}_{t=0}^T, t \in [0, T]
{x(t)}t=0T,t∈[0,T],其满足
x
(
0
)
∼
p
0
\mathbf{x}(0)\sim p_0
x(0)∼p0是需要学习的目标数据分布,
x
(
T
)
∼
p
T
\mathbf{x}(T)\sim p_T
x(T)∼pT是便于采样的先验分布。这个扩散过程可以用下面的随机微分方程(SDE)的解表示:
d
x
=
f
(
x
,
t
)
d
t
+
g
(
t
)
d
w
(5)
\mathrm{d}\mathbf{x} = f(\mathbf{x}, t)\mathrm{d}t + g(t)\mathrm{d}\mathbf{w} \tag{5}
dx=f(x,t)dt+g(t)dw(5)
w
\mathbf{w}
w是标准Wiener过程,
f
(
⋅
,
t
)
f(\cdot, t)
f(⋅,t)是称为漂移系数(drift coefficient)的向量函数,
g
(
t
)
g(t)
g(t)是称为扩散系数(diffusion coefficient)的标量函数。
通过从
x
(
T
)
∼
p
T
\mathbf{x}(T)\sim p_T
x(T)∼pT采样,并逆转上面的过程,我们可以得到
x
(
0
)
∼
p
0
\mathbf{x}(0)\sim p_0
x(0)∼p0,从而得到目标数据分布的样本。已有工作证明上面扩散过程的逆仍然是一个扩散过程,不过时间从T到0:
d
x
=
[
f
(
x
,
t
)
−
g
(
t
)
2
∇
x
log
p
t
(
x
)
]
d
t
+
g
(
t
)
d
w
‾
(6)
\mathrm{d}\mathbf{x} = [f(\mathbf{x}, t) - g(t)^2 \nabla_\mathbf{x}\log p_t(\mathbf{x})]\mathrm{d}t + g(t)\mathrm{d}\overline{\mathbf{w}} \tag{6}
dx=[f(x,t)−g(t)2∇xlogpt(x)]dt+g(t)dw(6)
w
‾
\overline{\mathbf{w}}
w是时间从T到0的标准Wiener过程。
∇
x
log
p
t
(
x
)
\nabla_\mathbf{x}\log p_t(\mathbf{x})
∇xlogpt(x)被称为分数score。如果可以得到分数
∇
x
log
p
t
(
x
)
\nabla_\mathbf{x}\log p_t(\mathbf{x})
∇xlogpt(x),那么就可以通过逆过程采样
p
0
p_0
p0的样本。
为了得到 x ( 0 ) ∼ p 0 \mathbf{x}(0)\sim p_0 x(0)∼p0,需要先求分数 ∇ x log p t ( x ) \nabla_\mathbf{x}\log p_t(\mathbf{x}) ∇xlogpt(x),再根据公式(6)的逆扩散过程求解 x ( T ) \mathbf{x}(T) x(T)。
分数估计
∇
x
log
p
t
(
x
)
\nabla_\mathbf{x}\log p_t(\mathbf{x})
∇xlogpt(x)可以通过训练一个基于分数的模型来估计。
当
f
(
⋅
,
t
)
f(\cdot, t)
f(⋅,t)是仿射变换时,变换核(transition kernel)
p
0
t
(
x
(
t
)
∣
x
(
0
)
)
p_{0t}(\mathbf x(t) | \mathbf x(0))
p0t(x(t)∣x(0))是高斯分布,其均值和方差都可以求出来,所以也可以求得导数的表达式。如果
f
(
⋅
,
t
)
f(\cdot, t)
f(⋅,t)是其他变换,可以通过求Kolmogorov’s forward equation得到变换核,也可以通过sliced score matching算法从SDE采样得到版本,再绕过显式计算
∇
x
(
t
)
log
p
0
t
(
x
(
t
)
∣
x
(
0
)
)
\nabla_\mathbf{x}(t)\log p_{0t}(\mathbf x(t) | \mathbf x(0))
∇x(t)logp0t(x(t)∣x(0))。
逆SDE求解
可以用通用的SDE求解算法直接求逆时间SDE。但因为我们有score模型,所以我们考虑更好的方法,也就是利用score-based MCMC方法。
作者提出Predictor-Corrector (PC) samplers。在每一个时间步,首先用SDE求解器估计下一个时间步的样本(predictor),然后再用score-based MCMC方法修正估计样本的边际分布(corrector)。
概率流
对于所有的扩散过程,存在一个确定的过程,其轨迹和扩散过程有相同的边际概率密度(marginal probability densities)
{
p
(
x
)
}
t
=
0
T
\{p(\mathbf{x})\}_{t=0}^T
{p(x)}t=0T。公式(5)对应的确定过程满足下面的常微分方程(ODE):
d
x
=
[
f
(
x
,
t
)
−
1
2
g
(
t
)
2
∇
x
log
p
t
(
x
)
]
d
t
\mathrm{d}\mathbf{x} = [f(\mathbf{x}, t) - \frac{1}{2}g(t)^2 \nabla_\mathbf{x}\log p_t(\mathbf{x})]\mathrm{d}t
dx=[f(x,t)−21g(t)2∇xlogpt(x)]dt作者将上面的ODE称为概率流ODE(probability flow ODE)。
DDPM不能直接计算likelihood,只能用ELBO代替likelihood。而扩散过程的likelihood可以通过转换为概率流ODE计算。因为概率流ODE算是neural ODE的一个特例,neural ODE的论文中已经证明其可以计算flow的概率,所以通过instantaneous change of variables formula可以计算概率流ODE的概率。
另外,因为公式(5)没有可学习的参数,所以如果有完美估计的score,那么通过概率流ODE可以得到数据和隐含表示的一对一映射关系。毕竟概率流ODE类似normalize flow是一个确定的过程,而不像DDPM等方法一样,在不停的采样。
可控生成
可控生成可以通过求解条件逆时间SDE得到
d
x
=
{
f
(
x
,
t
)
−
g
(
t
)
2
[
∇
x
log
p
t
(
x
)
+
∇
x
log
p
t
(
y
∣
x
)
]
}
d
t
+
g
(
t
)
d
w
‾
\mathrm{d}\mathbf{x} = \{f(\mathbf{x}, t) - g(t)^2 [\nabla_\mathbf{x}\log p_t(\mathbf{x}) + \nabla_\mathbf{x}\log p_t(\mathbf{y|x})]\}\mathrm{d}t + g(t)\mathrm{d}\overline{\mathbf{w}}
dx={f(x,t)−g(t)2[∇xlogpt(x)+∇xlogpt(y∣x)]}dt+g(t)dw
VE, VP SDES
SMLD 和 DDPM 中使用的噪声扰动可以看作是两个不同 SDE 的离散化。
SMLD对应的SDE如下,其随着t的增大方差变得越来越大,所以被称为Variance Exploding (VE) SDE。
d
x
=
d
[
σ
2
t
]
d
t
d
w
(9)
\mathrm{d}\mathbf{x} = \sqrt{\frac{\mathrm{d [\sigma^2t]}}{\mathrm{d} t}}\mathrm{d}\mathbf{w} \tag{9}
dx=dtd[σ2t]dw(9)
DDPM对应的SDE如下,其方差是有界的,并且上界是时刻0的方差。当时刻0的方差为
I
\mathbf I
I时,SDE在任何时刻的方差都是
I
\mathbf I
I,所以被称为Variance Preserving (VP) SDE。
d
x
=
−
1
2
β
(
t
)
x
d
t
+
β
(
t
)
d
w
(11)
\mathrm{d}\mathbf{x} = -\frac{1}{2}\beta(t)\mathbf x \mathrm{d}t + \sqrt{\beta(t)}\mathrm{d}\mathbf{w} \tag{11}
dx=−21β(t)xdt+β(t)dw(11)
作者提出了一种新的SDE称为sub-VP SDE,如下所示。在中间时刻,该SDE的方差总是以VP SDE的方差为上界。实验表明,sub-VP SDE的likelihood很低。
d
x
=
−
1
2
β
(
t
)
x
d
t
+
β
(
t
)
(
1
−
e
−
2
∫
0
t
β
(
s
)
d
s
)
d
w
(12)
\mathrm{d}\mathbf{x} = -\frac{1}{2}\beta(t)\mathbf x \mathrm{d}t + \sqrt{\beta(t)(1-e^{-2\int_0^t\beta(s)\mathrm{d}s})}\mathrm{d}\mathbf{w} \tag{12}
dx=−21β(t)xdt+β(t)(1−e−2∫0tβ(s)ds)dw(12)
这三种方法的变换核都是高斯分布,可以以封闭形式计算,所以分数模型可以高效的训练。