1、前言
本篇文章,将从SDE(随机微分方程)视角,去解释前两个模型(DDPM、NCSN),将它们统一起来。并从SDE视角上,提供一个具有更高似然的新模型。与此同时,提供各种新式采样方法。
参考论文:
①Score-Based Generative Modeling through Stochastic Differential Equations (arxiv.org)
②Tutorial on Diffusion Models for Imaging and Vision (arxiv.org)(这篇论文有些错误,请注意甄别)
2、DDPM、NCSN回顾及统一
2.1、NCSN
NCSN,我沿用我之前那篇文章的符号噪声条件分数网络
其损失函数为
L
=
1
S
∑
i
=
1
S
λ
i
1
2
E
P
d
a
t
a
(
x
)
,
x
~
∼
N
(
x
,
σ
i
2
I
)
[
∣
∣
s
θ
(
x
+
σ
i
z
,
σ
i
)
+
x
~
i
−
x
σ
i
2
∣
∣
2
2
]
L=\frac{1}{S}\sum\limits_{i=1}^S\lambda_i\frac{1}{2}\mathbb{E}_{P_{data}(x),\tilde x\sim N(x,\sigma_i^2I)}\left[||s_\theta(x+\sigma_i z,\sigma_i)+\frac{\tilde x_i -x}{\sigma_i^2}||_2^2\right]
L=S1i=1∑Sλi21EPdata(x),x~∼N(x,σi2I)[∣∣sθ(x+σiz,σi)+σi2x~i−x∣∣22]
其中S表示噪声量级的个数,
λ
i
\lambda_i
λi是缩放系数,一般取
σ
i
2
\sigma_i^2
σi2,而
x
~
i
−
x
σ
i
2
\frac{\tilde x_i -x}{\sigma_i^2}
σi2x~i−x由
q
σ
i
(
x
~
∣
x
)
q_{\sigma_i}(\tilde x|x)
qσi(x~∣x)变化而来(
σ
i
\sigma_i
σi表示的是加噪强度为
σ
i
\sigma_i
σi)。Ps:论文将
q
σ
(
x
~
∣
x
)
q_{\sigma}(\tilde x|x)
qσ(x~∣x)称为扰动核
同时在NCSN里面,我们曾得到
x
~
i
−
x
σ
i
2
=
z
σ
i
\frac{\tilde x_i -x}{\sigma_i^2}=\frac{z}{\sigma_i}
σi2x~i−x=σiz,于是便有
L
=
1
S
∑
i
=
1
S
λ
i
1
2
E
P
d
a
t
a
(
x
)
,
z
∼
N
(
0
,
I
)
[
∣
∣
s
θ
(
x
+
σ
i
z
,
σ
i
)
+
z
σ
i
∣
∣
2
2
]
(1)
L=\frac{1}{S}\sum\limits_{i=1}^S\lambda_i\frac{1}{2}\mathbb{E}_{P_{data}(x),z\sim N(0,I)}\left[||s_\theta(x+\sigma_i z,\sigma_i)+\frac{z}{\sigma_i}||_2^2\right]\tag{1}
L=S1i=1∑Sλi21EPdata(x),z∼N(0,I)[∣∣sθ(x+σiz,σi)+σiz∣∣22](1)
此时
s
θ
s_{\theta}
sθ预测的其实就是
−
z
σ
i
-\frac{z}{\sigma_i}
−σiz
Eq.(1)训练完成之后,就可以进行采样生成(郎之万动力学采样):
x
t
+
1
=
x
t
+
α
∇
x
log
P
(
x
t
)
+
2
α
z
t
x_{t+1}=x_t+\alpha \nabla_x\log P(x_t)+\sqrt{2\alpha}z_t
xt+1=xt+α∇xlogP(xt)+2αzt
其中
∇
x
log
P
(
x
t
)
≈
s
θ
\nabla_x\log P(x_t)\approx s_\theta
∇xlogP(xt)≈sθ,也就是采样生成的时候用神经网络代替里面的分数函数
∇
x
log
P
(
x
t
)
\nabla_x\log P(x_t)
∇xlogP(xt)
值得注意的是,其实叫NCSN不准确,NCSN对应的是用 q ( x ~ ∣ x ) q(\tilde x|x) q(x~∣x)加噪,来去避免 P ( x ) P(x) P(x)无法计算的问题。其实还存在其他的方法。如果单纯的从模型上来看,它们都是分数的,并且采样方法都是郎之万动力采样,所以它们有一个共同的名字,称为SMLD(Score-Matching Langevin Dynamics)。只是说NCSN更为常用。本篇文章,我用NCSN指代SMLD了(毕竟我上篇文章只讲了NCSN)
2.2、DDPM
加噪过程表示为 q ( x i ∣ x i − 1 ) = N ( x i ∣ 1 − β i x i − 1 , β i I ) q(x_i|x_{i-1})=\mathcal{N}(x_i|\sqrt{1-\beta_i}x_{i-1},\beta_iI) q(xi∣xi−1)=N(xi∣1−βixi−1,βiI),跳步加噪表示为 q ( x i ∣ x 0 ) = N ( x i ∣ α i x 0 , ( 1 − α i ) I ) q(x_i|x_0)=\mathcal{N}(x_i|\sqrt{\alpha_i}x_0,(1-\alpha_i)I) q(xi∣x0)=N(xi∣αix0,(1−αi)I)
请注意,在DDPM里面,其实应该是 q ( x i ∣ x 0 ) = N ( x i ∣ α ˉ i x 0 , ( 1 − α ˉ i ) I ) q(x_i|x_0)=\mathcal{N}(x_i|\sqrt{\bar\alpha_i}x_0,(1-\bar\alpha_i)I) q(xi∣x0)=N(xi∣αˉix0,(1−αˉi)I)。该论文把 α i \alpha_i αi等价了DDPM里面的 α ˉ i \bar \alpha_i αˉi,请不要被误导
去噪过程表示为
P
(
x
i
−
1
∣
x
i
)
P(x_{i-1}|x_i)
P(xi−1∣xi)。重参数化得到采样生成步骤(论文将该方法称为祖先采样)
x
i
−
1
=
1
1
−
β
i
(
x
i
−
β
i
1
−
α
i
ϵ
θ
(
x
i
,
i
)
)
+
σ
i
z
(2)
x_{i-1}=\frac{1}{\sqrt{1-\beta_i}}\left(x_i-\frac{\beta_i}{\sqrt{1-\alpha_i}}\epsilon_{\theta}(x_i,i)\right)+\sigma_iz\tag{2}
xi−1=1−βi1(xi−1−αiβiϵθ(xi,i))+σiz(2)
ϵ
θ
(
x
i
,
i
)
\epsilon_\theta(x_i,i)
ϵθ(xi,i)是用神经网络预测的,损失函数为
L
=
∣
∣
ϵ
i
−
ϵ
θ
(
x
i
,
i
)
∣
∣
2
L=||\epsilon_i-\epsilon_\theta(x_i,i)||^2
L=∣∣ϵi−ϵθ(xi,i)∣∣2
2.3、两者统一
此时我们注意到Eq.(2)里面的,
1
−
α
i
\sqrt{1-\alpha_i}
1−αi是标准差,相当于NCSN里面的
σ
i
\sigma_i
σi,而
ϵ
\epsilon
ϵ相当于NCSN里面的z,于是,我们可以令
−
ϵ
θ
(
x
i
,
i
)
1
−
α
i
=
s
θ
(
x
i
,
i
)
-\frac{\epsilon_\theta(x_i,i)}{\sqrt{1-\alpha_i}}=s_\theta(x_i,i)
−1−αiϵθ(xi,i)=sθ(xi,i),则Eq.(2)表示为
x
i
−
1
=
1
1
−
β
i
(
x
i
+
β
i
s
θ
(
x
i
,
i
)
)
+
σ
i
z
x_{i-1}=\frac{1}{\sqrt{1-\beta_i}}\left(x_i+\beta_is_\theta(x_i,i)\right)+\sigma_iz
xi−1=1−βi1(xi+βisθ(xi,i))+σiz
里面的
σ
i
\sigma_i
σi在这里取为
β
i
\sqrt{\beta_i}
βi,于是得到
x
i
−
1
=
1
1
−
β
i
(
x
i
+
β
i
s
θ
(
x
i
,
i
)
)
+
β
i
z
x_{i-1}=\frac{1}{\sqrt{1-\beta_i}}\left(x_i+\beta_is_\theta(x_i,i)\right)+\sqrt{\beta_i}z
xi−1=1−βi1(xi+βisθ(xi,i))+βiz
我们不难看到里面其实只有
s
θ
s_\theta
sθ是未知的,因此我们完全可以直接去预测一个
s
θ
s_\theta
sθ就可以了。所以损失函数就可以写成(可以根据KL散度导出,或者我们直观理解都可以了)
L
=
∣
∣
s
θ
(
x
i
,
i
)
+
ϵ
i
1
−
α
i
∣
∣
2
L=||s_\theta(x_i,i)+\frac{\epsilon_i}{\sqrt{1-\alpha_i}}||^2
L=∣∣sθ(xi,i)+1−αiϵi∣∣2
我们知道NCSN里面使用不同尺度的噪声给原始数据加噪;而DDPM里面也是一个不断的加噪过程。我们刚刚又进行了损失函数的转化,此时我们不难发现,DDPM和NCSN可以统一称为分数模型(
s
θ
s_{\theta}
sθ预测的是分数函数)
q σ ( x ~ ∣ x ) q_{\sigma}(\tilde x|x) qσ(x~∣x)表示从初始图像 x x x加噪到 x ~ \tilde x x~,可以写成与DDPM的形式 q ( x i ∣ x 0 ) q(x_i|x_0) q(xi∣x0),也就是从 x 0 x_0 x0加噪到 x i x_i xi
论文将 q ( x i ∣ x 0 ) q(x_i|x_0) q(xi∣x0)统称为扰动核
3、引入
3.1、微分方程
回忆一下导数
f
′
(
x
)
=
lim
Δ
x
→
0
f
(
x
+
Δ
x
)
−
f
(
x
)
Δ
x
=
d
f
(
x
)
d
x
f'(x)=\lim\limits_{\Delta x \to 0}\frac{f(x+\Delta x)-f(x)}{\Delta x}=\frac{df(x)}{dx}
f′(x)=Δx→0limΔxf(x+Δx)−f(x)=dxdf(x)
假设我们现在有这么一个函数
x
(
t
+
Δ
t
)
=
(
1
−
β
Δ
t
2
)
x
(
t
)
x(t+\Delta t)=(1-\frac{\beta\Delta t}{2})x(t)
x(t+Δt)=(1−2βΔt)x(t)
移项整理得
x
(
t
+
Δ
t
)
−
x
(
t
)
Δ
t
=
−
β
2
x
(
t
)
⟹
Δ
t
→
0
d
x
(
t
)
d
t
=
−
β
2
x
(
t
)
\frac{x(t+\Delta t)-x(t)}{\Delta t}=-\frac{\beta}{2}x(t) \overset{\Delta t \to0}{\Longrightarrow} \frac{dx(t)}{dt}=-\frac{\beta}{2}x(t)
Δtx(t+Δt)−x(t)=−2βx(t)⟹Δt→0dtdx(t)=−2βx(t)
我们称上式为常微分方程(ODE)
微分方程定义:含有未知函数的导数,如 d y d x = 2 x \frac{dy}{dx}=2x dxdy=2x的方程是微分方程。 一般的凡是表示未知函数、未知函数的导数与自变量之间的关系的方程,叫做微分方程。未知函数是一元函数的,叫常微分方程;未知函数是多元函数的叫做偏微分方程。微分方程有时也简称方程
3.2、随机微分方程(SDE)
伊藤SDE表达式为
d
x
=
f
(
x
,
t
)
d
t
+
g
(
t
)
d
w
(3)
\mathbb{dx=f(x,t)dt}+g(t)\mathbb{dw}\tag{3}
dx=f(x,t)dt+g(t)dw(3)
f
(
⋅
,
t
)
:
R
d
→
R
d
\mathbb{f(\cdot,t)}:\mathbb{R}^d\to\mathbb{R}^d
f(⋅,t):Rd→Rd是一个向量函数,一般称为
x
(
t
)
x(t)
x(t)的漂移系数;
g
(
⋅
)
:
R
→
R
g(\cdot):\mathbb{R\to R}
g(⋅):R→R是一个标量函数,一般称为
x
(
t
)
x(t)
x(t)的扩散系数。此处为了表达的简便,暂且把扩散系数表示成一维的标量(实际上还存在
d
×
d
d\times d
d×d维的情况)
而
w
\mathbb{w}
w是标准维纳过程(布朗运动),实际上其差分就是一个方差随着时间变化而变化的高斯分布
d
w
=
w
(
t
+
Δ
t
)
−
w
(
t
)
=
Δ
t
z
(
t
)
\mathbb{dw=}\mathbb{w(t+\Delta t)-w(t)}=\sqrt{\Delta t}z(t)
dw=w(t+Δt)−w(t)=Δtz(t)
其中
z
(
t
)
=
N
(
0
,
1
)
z(t) = \mathcal{N}(0,1)
z(t)=N(0,1),所以
d
w
=
Δ
t
z
(
t
)
=
N
(
0
,
Δ
t
)
\mathbb{dw}=\sqrt{\Delta t}z(t)=N(0,\Delta t)
dw=Δtz(t)=N(0,Δt)
Eq.(3)的解是
x
(
t
)
x(t)
x(t),也就是一个随着时间变化而变化的x。
4、分数模型与SDE的关系
4.1、加噪过程连续化
①DDPM:
DDPM是一个离散的马尔可夫链,加噪过程是一个线性高斯。
离散意为着时刻t是有限的,而SDE,就是将时刻扩展至无限
我们可以证明,DDPM的加噪过程,可以用伊藤SDE去表示。为什么要用SDE去表示?当我们用SDE表示之后,我们就可以得到连续时间的 x ( t ) x(t) x(t),也就是随着时间变化而变化的x(对应加噪过程)
下面我们将DDPM的加噪过程转化成伊藤SDE
回忆一下DDPM的加噪表达式
q
(
x
i
∣
x
i
−
1
)
q(x_{i}|x_{i-1})
q(xi∣xi−1)
x
i
=
1
−
β
i
x
i
−
1
+
β
i
z
i
−
1
(4)
x_i=\sqrt{1-\beta_i} x_{i-1}+\sqrt{\beta_i}z_{i-1}\tag{4}
xi=1−βixi−1+βizi−1(4)
其中 z i − 1 z_{i-1} zi−1服从标准正太分布
Ps:在我之前的文章不是 z i − 1 z_{i-1} zi−1,而是用 z i z_i zi( z i z_i zi无论在什么时刻都是标准正太,所以其实写成什么区别不大)。在这里我用论文的表示方法
对于离散的DDPM,我们有 { x i } i = 1 N \{x_i\}_{i=1}^N {xi}i=1N(也就是有N个时刻,在DDPM里面用T表示有T个时刻,这篇论文用N表示有N个时刻)
为了简便,构造一个
{
β
ˉ
i
=
N
β
i
}
i
=
1
N
\{\bar\beta_i=N\beta_i\}_{i=1}^N
{βˉi=Nβi}i=1N,所以有
β
i
=
1
N
β
ˉ
i
\beta_i=\frac{1}{N}\bar\beta_i
βi=N1βˉi,代入Eq.(4)
x
i
=
1
−
1
N
β
ˉ
i
x
i
−
1
+
1
N
β
ˉ
i
z
i
−
1
(5)
x_i=\sqrt{1-\frac{1}{N}\bar\beta_i} x_{i-1}+\sqrt{\frac{1}{N}\bar\beta_i}z_{i-1}\tag{5}
xi=1−N1βˉixi−1+N1βˉizi−1(5)
当
N
→
∞
N\to \infty
N→∞,构造连续的随机过程
{
x
(
t
)
}
t
=
0
1
\{x(t)\}_{t=0}^1
{x(t)}t=01,其中
t
∈
[
0
,
1
]
t\in[0,1]
t∈[0,1]。同样的
{
β
ˉ
i
}
i
=
1
N
\{\bar\beta_i\}_{i=1}^N
{βˉi}i=1N也变成了
t
∈
[
0
,
1
]
t\in[0,1]
t∈[0,1]的索引函数
β
(
t
)
\beta(t)
β(t),其他也是同理,
1
N
=
Δ
t
\frac{1}{N}=\Delta t
N1=Δt,令
x
i
=
x
(
i
N
)
=
x
(
t
+
Δ
t
)
x_i=x(\frac{i}{N})=x(t+\Delta t)
xi=x(Ni)=x(t+Δt),
z
i
=
z
(
1
N
)
=
z
(
t
+
Δ
t
)
z_i=z(\frac{1}{N})=z(t+\Delta t)
zi=z(N1)=z(t+Δt),
于是Eq.(5)可表示为
x
(
t
+
Δ
t
)
=
1
−
β
(
t
+
Δ
t
)
Δ
t
x
(
t
)
+
β
(
t
+
Δ
t
)
Δ
t
z
(
t
)
(6)
x(t+\Delta t)=\sqrt{1-\beta(t+\Delta t)\Delta t}x(t)+\sqrt{\beta(t+\Delta t)\Delta t}z(t)\tag{6}
x(t+Δt)=1−β(t+Δt)Δtx(t)+β(t+Δt)Δtz(t)(6)
注意到
1
−
β
(
t
+
Δ
t
)
Δ
t
\sqrt{1-\beta(t+\Delta t)\Delta t}
1−β(t+Δt)Δt,当
β
(
t
+
Δ
t
)
Δ
t
→
0
\beta(t+\Delta t)\Delta t\to 0
β(t+Δt)Δt→0
我们使用一阶泰勒展开 1 − β ( t + Δ t ) Δ t ≈ 1 − 1 2 β ( t + Δ t ) Δ t \sqrt{1-\beta(t+\Delta t)\Delta t}\approx 1-\frac{1}{2}\beta(t+\Delta t)\Delta t 1−β(t+Δt)Δt≈1−21β(t+Δt)Δt
代入Eq.(6),当
Δ
t
→
0
\Delta t \to 0
Δt→0,进一步有
x
(
t
+
Δ
t
)
=
(
1
−
1
2
β
(
t
+
Δ
t
)
Δ
t
)
x
(
t
)
+
β
(
t
+
Δ
t
)
Δ
t
z
(
t
)
≈
x
(
t
)
−
1
2
β
(
t
)
Δ
t
x
(
t
)
+
β
(
t
)
Δ
t
z
(
t
)
=
x
(
t
)
−
1
2
β
(
t
)
x
(
t
)
Δ
t
+
β
(
t
)
Δ
t
z
(
t
)
\begin{aligned}x(t+\Delta t)=&\left(1-\frac{1}{2}\beta(t+\Delta t)\Delta t\right)x(t)+\sqrt{\beta(t+\Delta t)\Delta t}z(t)\\\approx&x(t)-\frac{1}{2}\beta(t)\Delta tx(t)+\sqrt{\beta(t)\Delta t}z(t)\\=&x(t)-\frac{1}{2}\beta(t)x(t)\Delta t+\sqrt{\beta(t)}\sqrt{\Delta t}z(t)\end{aligned}\nonumber
x(t+Δt)=≈=(1−21β(t+Δt)Δt)x(t)+β(t+Δt)Δtz(t)x(t)−21β(t)Δtx(t)+β(t)Δtz(t)x(t)−21β(t)x(t)Δt+β(t)Δtz(t)
将
x
(
t
)
x(t)
x(t)移至左侧
x
(
t
+
Δ
t
)
−
x
(
t
)
=
−
1
2
β
(
t
)
x
(
t
)
Δ
t
+
β
(
t
)
Δ
t
z
(
t
)
⟹
d
x
=
−
1
2
β
(
t
)
x
(
t
)
⏟
f
(
x
,
t
)
d
t
+
β
(
t
)
⏟
g
(
t
)
d
w
\begin{aligned}&x(t+\Delta t)-x(t)=-\frac{1}{2}\beta(t)x(t)\Delta t+\sqrt{\beta(t)}\sqrt{\Delta t}z(t)\\\Longrightarrow&dx=\underbrace{-\frac{1}{2}\beta(t)x(t)}_{\mathbb{f(x,t)}}dt+\underbrace{\sqrt{\beta(t)}}_{g(t)}dw\end{aligned}\nonumber
⟹x(t+Δt)−x(t)=−21β(t)x(t)Δt+β(t)Δtz(t)dx=f(x,t)
−21β(t)x(t)dt+g(t)
β(t)dw
由此,我们便得到了DDPM的SDE表达形式,它被表达成VP SDE
②NCSN:
我们同样也将NCSN的加噪过程写成SDE的表达形式
首先回忆一下NCSN的加噪
x
~
=
x
+
σ
z
\tilde x = x+\sigma z
x~=x+σz
其中
z
z
z服从标准正态分布,
σ
\sigma
σ是噪声强度,在NCSN中,我们是对原始图像x加上不同强度的噪声
σ
i
\sigma_i
σi,一共有N个强度噪声
{
σ
i
}
i
=
1
N
\{\sigma_i\}_{i=1}^N
{σi}i=1N。
其实这种每次都对原始图像加不同噪声,可以理解和DDPM那般,一点点的加入噪声,每个加噪时刻就对应不同强度的噪声。两者是等价的,稍微想一想就知道了,于是我们可以表达加噪过程为
q
(
x
i
∣
x
i
−
1
)
q(x_i|x_{i-1})
q(xi∣xi−1)
x
i
=
x
i
−
1
+
σ
i
2
−
σ
i
−
1
2
z
i
−
1
x_i=x_{i-1}+\sqrt{\sigma_i^2-\sigma_{i-1}^2}z_{i-1}
xi=xi−1+σi2−σi−12zi−1
为什么是这样呢?我们不妨证明一下就知道了(证明过程与DDPM类似)
由之前在DDPM里面讲过的正太分布定理,有
σ
i
2
−
σ
i
−
1
2
z
i
−
1
∼
N
(
0
,
σ
i
2
−
σ
i
−
1
2
)
\sqrt{\sigma_i^2-\sigma_{i-1}^2}z_{i-1} \sim N(0,\sigma_i^2-\sigma_{i-1}^2)
σi2−σi−12zi−1∼N(0,σi2−σi−12)
E
[
x
i
]
=
E
[
x
i
−
1
]
+
E
[
σ
i
2
−
σ
i
−
1
2
z
i
−
1
]
=
x
i
−
1
\mathbb{E}[x_i]=\mathbb{E}[x_{i-1}]+\mathbb{E}\left[\sqrt{\sigma_i^2-\sigma_{i-1}^2}z_{i-1}\right]=x_{i-1}
E[xi]=E[xi−1]+E[σi2−σi−12zi−1]=xi−1
V a r [ x i ] = V a r [ x i − 1 ] + V a r [ σ i 2 − σ i − 1 2 z i − 1 ] = σ i 2 − σ i − 1 2 Var\left[x_i\right]=Var[x_{i-1}]+Var\left[\sqrt{\sigma_i^2-\sigma_{i-1}^2}z_{i-1}\right]=\sigma_{i}^2-\sigma_{i-1}^2 Var[xi]=Var[xi−1]+Var[σi2−σi−12zi−1]=σi2−σi−12
所以有 q ( x i ∣ x i − 1 ) ∼ N ( x i − 1 , σ i 2 − σ i − 1 2 ) q(x_i|x_{i-1})\sim N(x_{i-1},\sigma_{i}^2-\sigma_{i-1}^2) q(xi∣xi−1)∼N(xi−1,σi2−σi−12)
由
q
(
x
i
−
1
∣
x
i
−
2
)
q(x_{i-1}|x_{i-2})
q(xi−1∣xi−2)可得
x
i
−
1
=
x
i
−
2
+
σ
i
−
1
2
−
σ
i
−
2
2
z
i
−
2
x_{i-1}=x_{i-2}+\sqrt{\sigma_{i-1}^2-\sigma_{i-2}^2}z_{i-2}
xi−1=xi−2+σi−12−σi−22zi−2
结合上面两个式子可得
x
i
=
x
i
−
2
+
σ
i
−
1
2
−
σ
i
−
2
2
z
i
−
2
+
σ
i
2
−
σ
i
−
1
2
z
i
−
1
=
x
i
−
2
+
σ
i
2
−
σ
i
−
2
2
z
i
\begin{aligned}x_i=&x_{i-2}+\sqrt{\sigma_{i-1}^2-\sigma_{i-2}^2}z_{i-2}+\sqrt{\sigma_i^2-\sigma_{i-1}^2}z_{i-1}\\=&x_{i-2}+\sqrt{\sigma_i^2-\sigma^2_{i-2}}z_{i}\end{aligned}\nonumber
xi==xi−2+σi−12−σi−22zi−2+σi2−σi−12zi−1xi−2+σi2−σi−22zi
第二个等号的原因是正太分布的可加性
σ
i
−
1
2
−
σ
i
−
2
2
z
i
−
2
+
σ
i
2
−
σ
i
−
1
2
z
i
−
1
∼
N
(
0
,
(
σ
i
2
−
σ
i
−
1
2
)
+
(
σ
i
−
1
2
−
σ
i
−
2
2
)
)
=
N
(
0
,
σ
i
2
−
σ
i
−
2
2
)
\sqrt{\sigma_{i-1}^2-\sigma_{i-2}^2}z_{i-2}+\sqrt{\sigma_i^2-\sigma_{i-1}^2}z_{i-1}\sim N(0,(\sigma_i^2-\sigma_{i-1}^2)+(\sigma_{i-1}^2-\sigma_{i-2}^2))=N(0,\sigma_{i}^2-\sigma_{i-2}^2)
σi−12−σi−22zi−2+σi2−σi−12zi−1∼N(0,(σi2−σi−12)+(σi−12−σi−22))=N(0,σi2−σi−22)
这是对应
q
(
x
i
∣
x
i
−
2
)
q(x_i|x_{i-2})
q(xi∣xi−2),也就是
q
(
x
i
∣
x
i
−
2
)
∼
N
(
x
i
;
x
i
−
2
,
σ
i
2
−
σ
i
−
2
2
)
q(x_i|x_{i-2})\sim N(x_i;x_{i-2},\sigma_{i}^2-\sigma_{i-2}^2)
q(xi∣xi−2)∼N(xi;xi−2,σi2−σi−22)
以此类推得到 q ( x i ∣ x 0 ) ∼ N ( x i ; x 0 , σ i 2 − σ 0 2 ) = N ( x i ; x 0 , σ i 2 ) q(x_i|x_{0})\sim N(x_i;x_{0},\sigma_{i}^2-\sigma_{0}^2)=N(x_i;x_{0},\sigma_{i}^2) q(xi∣x0)∼N(xi;x0,σi2−σ02)=N(xi;x0,σi2)
回忆一下NCSN里面的 q σ i ( x ~ ∣ x ) q_{\sigma_i}(\tilde x|x) qσi(x~∣x),可以看到加噪过程完全符合 q σ i ( x ~ ∣ x ) q_{\sigma_i}(\tilde x|x) qσi(x~∣x)。比如 q ( x i ∣ x 0 ) q(x_i|x_0) q(xi∣x0),也完全符合噪声强度为 σ i \sigma_i σi,期望为 x 0 x_0 x0的情况( σ 0 = 0 \sigma_0=0 σ0=0)。这很容易理解,稍微想一想就明白了。
现在我们开始将其转化为SDE的形式
x
(
t
+
Δ
t
)
=
x
(
t
)
+
σ
2
(
t
+
Δ
t
)
−
σ
2
(
t
)
z
(
t
)
=
x
(
t
)
+
σ
2
(
t
+
Δ
t
)
−
σ
2
(
t
)
Δ
t
Δ
t
z
(
t
)
\begin{aligned}x(t+\Delta t)=&x(t)+\sqrt{\sigma^2(t+\Delta t)-\sigma^2(t)}z(t)\\=&x(t)+\sqrt{\frac{\sigma^2(t+\Delta t)-\sigma^2(t)}{\Delta t}\Delta t}z(t)\end{aligned}\nonumber
x(t+Δt)==x(t)+σ2(t+Δt)−σ2(t)z(t)x(t)+Δtσ2(t+Δt)−σ2(t)Δtz(t)
移项,再把第二项的
Δ
t
\Delta t
Δt提出来
x
(
t
+
Δ
t
)
−
x
(
t
)
=
σ
2
(
t
+
Δ
t
)
−
σ
2
(
t
)
Δ
t
Δ
t
z
(
t
)
x(t+\Delta t)-x(t)=\sqrt{\frac{\sigma^2(t+\Delta t)-\sigma^2(t)}{\Delta t}}\sqrt{\Delta t}z(t)
x(t+Δt)−x(t)=Δtσ2(t+Δt)−σ2(t)Δtz(t)
当
Δ
t
→
0
\Delta t\to 0
Δt→0,则有
d
x
=
0
⏟
f
(
x
,
t
)
+
d
[
σ
2
(
t
)
]
d
t
⏟
g
(
t
)
d
w
dx = \underbrace{0}_{\mathbb{f(x,t)}}+\underbrace{\sqrt{\frac{d\left[\sigma^2(t)\right]}{dt}}}_{g(t)}\mathbb{dw}
dx=f(x,t)
0+g(t)
dtd[σ2(t)]dw
由此,我们便得到了NCSN的SDE表达式,它被表达为 VE SDE
③sub-VP SDE:
作者在此基础上,又提出了一个新的SDE,称为sub-VP SDE
d
x
=
−
1
2
β
(
t
)
x
⏟
f
(
x
,
t
)
d
t
+
β
(
t
)
(
1
−
e
−
2
∫
0
t
β
(
s
)
d
s
)
⏟
g
(
t
)
d
w
\mathbb{dx=\underbrace{-\frac{1}{2}\beta(t)x}_{f(x,t)}dt+\underbrace{\sqrt{\beta(t)\left(1-e^{-2\int_0^t\beta(s)ds}\right)}}_{g(t)}dw}
dx=f(x,t)
−21β(t)xdt+g(t)
β(t)(1−e−2∫0tβ(s)ds)dw
4.2、扰动核
有了上述过程DDPM和NCSN的连续化表达,那我们还需要做什么呢?
我们需要求出扩散核,我们注意到模型训练的时候,神经网络的输入是 x i x_i xi和 i i i( x i x_i xi)是加噪后的图像, i i i是时刻(忘记了的请看第二节)。当扩散到连续之后,神经网络的输入就变成了 x ( t ) x(t) x(t)和 t t t。
我们需要计算出连续型的扩散核 q 0 t ( x ( t ) ∣ x ( 0 ) ) q_{0t}(x(t)|x(0)) q0t(x(t)∣x(0))(实际上,论文写的是 P 0 t ( x ( t ) ∣ x ( 0 ) ) P_{0t}(x(t)|x(0)) P0t(x(t)∣x(0)),我在此直接沿用DDPM和NCSN里面的写法)
由于DDPM和NCSN里面的加噪过程是一个线性高斯加噪。
当时间被无限压缩,就相当于有无限个高斯分布相加,其结果仍然是高斯,因此 q 0 t ( x ( t ) ∣ x ( 0 ) ) q_{0t}(x(t)|x(0)) q0t(x(t)∣x(0))仍然是高斯
那么该如何求其参数呢?事实上,论文并没有对这一部分进行推导,而是给出了一篇论文里面的公式,然后告诉我们可以依靠这个公式,求出扰动核的期望跟方差。所以我在此处就不推导了。感兴趣的,可以看参考①,这位大佬推导过了。以下,我直接给出结论
q
0
t
(
x
(
t
)
∣
x
(
0
)
)
=
{
N
(
x
(
t
)
;
x
(
0
)
,
[
σ
2
(
t
)
−
σ
2
(
0
)
]
I
)
,
(
V
E
S
D
E
)
N
(
x
(
t
)
;
x
(
0
)
e
−
1
2
∫
0
t
β
(
s
)
d
s
,
I
−
I
e
−
∫
0
t
β
(
s
)
d
s
)
,
(
V
P
S
D
E
)
N
(
x
(
t
)
;
x
(
0
)
e
−
1
2
∫
0
t
β
(
s
)
d
s
,
[
1
−
e
−
∫
0
t
β
(
s
)
d
s
]
2
I
)
,
(
s
u
b
−
V
P
S
D
E
)
q_{0t}(x(t)|x(0))=\begin{cases}\mathcal{N}\left(x(t);x(0),[\sigma^2(t)-\sigma^2(0)]I\right),\quad (VE\quad SDE) \\\mathcal{N}\left(x(t);x(0)e^{-\frac{1}{2}\int_0^t \beta(s)ds},I-Ie^{-\int_0^t\beta(s)ds}\right),\quad (VP\quad SDE)\\\mathcal{N}\left(x(t);x(0)e^{-\frac{1}{2}\int_0^t\beta(s)ds},\left[1-e^{-\int_0^t\beta(s)ds}\right]^2I\right),\quad(sub-VP\quad SDE)\end{cases}\nonumber
q0t(x(t)∣x(0))=⎩
⎨
⎧N(x(t);x(0),[σ2(t)−σ2(0)]I),(VESDE)N(x(t);x(0)e−21∫0tβ(s)ds,I−Ie−∫0tβ(s)ds),(VPSDE)N(x(t);x(0)e−21∫0tβ(s)ds,[1−e−∫0tβ(s)ds]2I),(sub−VPSDE)
对于时间t,我们之前提到
t
∈
[
0
,
1
]
t\in[0,1]
t∈[0,1],但是,VE SDE在t=0处是不连续的(
σ
(
0
)
≠
σ
(
0
+
)
\sigma(0)\neq \sigma(0^{+})
σ(0)=σ(0+))
因此,会构造一个极小的值,使得 t ∈ [ ϵ , 1 ] t\in[\epsilon,1] t∈[ϵ,1],其中 ϵ \epsilon ϵ一般取 1 e − 5 1e-5 1e−5。后面作者经过实验发现,VP SDE和sub-VP SDE也把t的最小值设定为 ϵ \epsilon ϵ会获得更好的似然和稳定性,故而它们的t取值范围与VE SDE一样(VP SDE采样生成时 ϵ = 1 e − 3 \epsilon=1e-3 ϵ=1e−3)
有了这三个,就可以求出
x
(
t
)
x(t)
x(t),但仍然有一个问题,里面的
σ
(
t
)
、
β
(
s
)
\sigma(t)、\beta(s)
σ(t)、β(s),我们仍然没有说明是什么。在离散的时候,它们对应一个时刻具体的值。而在连续的时候,我们用一个关于时间的表达式去表示
{
σ
(
t
)
=
σ
min
(
σ
m
a
x
σ
m
i
n
)
t
β
(
t
)
=
β
ˉ
m
i
n
+
t
(
β
ˉ
m
a
x
−
β
ˉ
m
i
n
)
(7)
\begin{cases}\sigma(t)=\sigma_{\min}\left(\frac{\sigma_{max}}{\sigma_{min}}\right)^t\\\beta(t)=\bar\beta_{min}+t(\bar\beta_{max}-\bar\beta_{min})\end{cases}\tag{7}
⎩
⎨
⎧σ(t)=σmin(σminσmax)tβ(t)=βˉmin+t(βˉmax−βˉmin)(7)
σ
(
t
)
\sigma(t)
σ(t)之所以这样取值,是因为
{
σ
i
}
i
N
\{\sigma_i\}_i^{N}
{σi}iN是一个等比数列,稍微想一想就明白了;而
β
(
t
)
\beta(t)
β(t)之所以这样取,是因为一般我们的
{
β
i
}
i
N
\{\beta_i\}_i^{N}
{βi}iN是一个等差数列。
把Eq.(7)代入三个SDE和对应的扰动核
VE SDE:
d
x
=
σ
m
i
n
(
σ
m
a
x
σ
m
i
n
)
t
2
log
σ
m
a
x
σ
m
i
n
d
w
q
0
t
(
x
(
t
)
∣
x
(
0
)
)
=
N
(
x
(
t
)
;
x
(
0
)
,
σ
m
i
n
2
(
σ
m
a
x
σ
m
i
n
)
2
t
I
)
\mathbb{dx=\sigma_{min}\left(\frac{\sigma_{max}}{\sigma_{min}}\right)^t\sqrt{2\log\frac{\sigma_{max}}{\sigma_{min}}}dw}\\ q_{0t}(x(t)|x(0))=\mathcal{N}\left(x(t);x(0),\sigma_{min}^2\left(\frac{\sigma_{max}}{\sigma_{min}}\right)^{2t}I\right)\nonumber
dx=σmin(σminσmax)t2logσminσmaxdwq0t(x(t)∣x(0))=N(x(t);x(0),σmin2(σminσmax)2tI)
SDE里面的导数用到导数法则
y
=
a
x
→
y
′
=
a
x
ln
a
y=a^x\to y'=a^x\ln a
y=ax→y′=axlna
VP SDE:
d
x
=
−
1
2
(
β
ˉ
m
i
n
+
t
(
β
ˉ
m
a
x
−
β
ˉ
m
i
n
)
)
x
d
t
+
β
ˉ
m
i
n
+
t
(
β
ˉ
m
a
x
−
β
ˉ
m
i
n
)
d
w
q
0
t
(
x
(
t
)
∣
x
(
0
)
)
=
N
(
x
(
t
)
;
e
−
1
4
t
2
(
β
ˉ
m
a
x
−
β
ˉ
m
i
n
)
−
1
2
t
β
ˉ
m
i
n
x
(
0
)
,
I
−
I
e
−
1
2
t
2
(
β
ˉ
m
a
x
−
β
ˉ
m
i
n
)
−
t
β
ˉ
m
i
n
)
\mathbb{dx=-\frac{1}{2}\left(\bar\beta_{min}+t(\bar\beta_{max}-\bar\beta_{min})\right)xdt+\sqrt{\bar\beta_{min}+t(\bar\beta_{max}-\bar\beta_{min})}dw}\\q_{0t}(x(t)|x(0))=\mathcal{N}\left(x(t);e^{-\frac{1}{4}t^2(\bar\beta_{max}-\bar\beta_{min})-\frac{1}{2}t\bar\beta_{min}}x(0),I-Ie^{-\frac{1}{2}t^2(\bar\beta_{max}-\bar\beta_{min})-t\bar\beta_{min}}\right)\nonumber
dx=−21(βˉmin+t(βˉmax−βˉmin))xdt+βˉmin+t(βˉmax−βˉmin)dwq0t(x(t)∣x(0))=N(x(t);e−41t2(βˉmax−βˉmin)−21tβˉminx(0),I−Ie−21t2(βˉmax−βˉmin)−tβˉmin)
sub-VP SDE:
d
x
=
−
1
2
(
β
ˉ
m
i
n
+
t
(
β
ˉ
m
a
x
−
β
ˉ
m
i
n
)
)
x
d
t
+
(
β
ˉ
m
i
n
+
t
(
β
ˉ
m
a
x
−
β
ˉ
m
i
n
)
)
(
1
−
e
−
t
2
(
β
ˉ
m
a
x
−
β
ˉ
m
i
n
)
−
2
t
β
ˉ
m
i
n
)
d
w
q
0
t
(
x
(
t
)
∣
x
(
0
)
)
=
N
(
x
(
t
)
;
e
−
1
4
t
2
(
β
ˉ
m
a
x
−
β
ˉ
m
i
n
)
−
1
2
t
β
ˉ
m
i
n
x
(
0
)
,
[
1
−
e
−
1
2
t
2
(
β
ˉ
m
a
x
−
β
ˉ
m
i
n
)
−
t
β
ˉ
m
i
n
]
2
I
)
\mathbb{dx=-\frac{1}{2}\left(\bar\beta_{min}+t(\bar\beta_{max}-\bar\beta_{min})\right)x}dt+\sqrt{\left(\bar\beta_{min}+t(\bar\beta_{max}-\bar\beta_{min})\right)\left(1-e^{-t^2(\bar\beta_{max}-\bar\beta_{min})-2t\bar\beta_{min} }\right)}dw\\q_{0t}(x(t)|x(0))=\mathcal{N}\left(x(t);e^{-\frac{1}{4}t^2(\bar\beta_{max}-\bar\beta_{min})-\frac{1}{2}t\bar\beta_{min}}x(0),\left[1-e^{-\frac{1}{2}t^2(\bar\beta_{max}-\bar\beta_{min})-t\bar\beta_{min}}\right]^2I\right)\nonumber
dx=−21(βˉmin+t(βˉmax−βˉmin))xdt+(βˉmin+t(βˉmax−βˉmin))(1−e−t2(βˉmax−βˉmin)−2tβˉmin)dwq0t(x(t)∣x(0))=N(x(t);e−41t2(βˉmax−βˉmin)−21tβˉminx(0),[1−e−21t2(βˉmax−βˉmin)−tβˉmin]2I)
VP SDE、sub-VP SDE用到牛顿-莱布尼兹公式 ∫ a b f ( x ) d x = F ( b ) − F ( a ) = F ( x ) ∣ a b \int_a^bf(x)dx=F(b)-F(a)=F(x)|_a^b ∫abf(x)dx=F(b)−F(a)=F(x)∣ab, F ( x ) F(x) F(x)是 f ( x ) f(x) f(x)的原函数
有了三个连续型的扩散核,我们随机采样一个时刻t,依据扩散核我们就可以得到 x ( t ) x(t) x(t),然后就可以训练了。
4.3、反向过程连续化
可以训练了,那么现在该如何去采样呢?
这篇论文Reverse-time diffusion equation models表明,扩散过程的反向过程也是一个扩散过程
反向随机微分方程表示为
d
x
=
[
f
(
x
,
t
)
−
g
(
t
)
2
∇
x
log
p
t
(
x
)
]
d
t
+
g
(
t
)
d
w
ˉ
(8)
\mathbb{dx}=\left[\mathbb{f(x,t)}-g(t)^2\nabla_x\log p_t(x)\right]\mathbb{dt}+g(t)\mathbb{d\bar w}\tag{8}
dx=[f(x,t)−g(t)2∇xlogpt(x)]dt+g(t)dwˉ(8)
其中,
d
w
ˉ
\mathbb{d\bar w}
dwˉ是一个标准的维纳过程,只不过是反向的,时间从T到0;而
d
t
\mathbb{dt}
dt是一个无穷小的负时间步长
d
t
=
−
Δ
t
dt=-\Delta t
dt=−Δt,
∇
x
log
p
t
(
x
)
\nabla_x\log p_t(x)
∇xlogpt(x)就是所谓的分数(在NCSN里面讲过,其实训练的时候近似的
s
θ
s_\theta
sθ)
因此,只要我们用神经网络估算出 s θ ( x ( t ) , t ) ≈ ∇ x log p t ( x ) s_\theta(x(t),t)\approx\nabla_x\log p_t(x) sθ(x(t),t)≈∇xlogpt(x)即可得到反向SDE
接下来,我们先证明这个反向SDE与我们之前的采样步骤的联系(本步骤你可以不看,毕竟给定正向过程,反向过程的SDE是唯一确定的,所以按理说我们没必要证明了(原论文其实也没有这一段)。但为了完整,我还是写吧)
①DDPM:
f
(
x
,
t
)
=
−
1
2
β
(
t
)
x
(
t
)
,
g
(
t
)
=
β
(
t
)
\mathbb{f(x,t)}=-\frac{1}{2}\beta(t)x(t),g(t)=\sqrt{\beta(t)}
f(x,t)=−21β(t)x(t),g(t)=β(t)
d
x
=
[
f
(
x
,
t
)
−
g
(
t
)
2
∇
x
log
p
t
(
x
)
]
d
t
+
g
(
t
)
d
w
ˉ
=
[
−
1
2
β
(
t
)
x
(
t
)
−
β
(
t
)
∇
x
log
p
t
(
x
)
]
d
t
+
β
(
t
)
d
w
ˉ
=
β
(
t
)
[
−
1
2
x
(
t
)
−
∇
x
log
p
t
(
x
)
]
d
t
+
β
(
t
)
d
w
ˉ
\begin{aligned}\mathbb{dx}=&\left[\mathbb{f(x,t)}-g(t)^2\nabla_x\log p_t(x)\right]\mathbb{dt}+g(t)\mathbb{d\bar w}\\=&\left[-\frac{1}{2}\beta(t)x(t)-\beta(t)\nabla_x\log p_t(x)\right]\mathbb{dt}+\sqrt{\beta(t)}\mathbb{d\bar w}\\=&\beta(t)\left[-\frac{1}{2}x(t)-\nabla_x\log p_t(x)\right]\mathbb{dt}+\sqrt{\beta(t)}\mathbb{d\bar w}\end{aligned}\nonumber
dx===[f(x,t)−g(t)2∇xlogpt(x)]dt+g(t)dwˉ[−21β(t)x(t)−β(t)∇xlogpt(x)]dt+β(t)dwˉβ(t)[−21x(t)−∇xlogpt(x)]dt+β(t)dwˉ
即
x
(
t
−
Δ
t
)
−
x
(
t
)
=
β
(
t
)
[
1
2
x
(
t
)
+
∇
x
log
p
t
(
x
)
]
Δ
t
+
β
(
t
)
Δ
t
z
(
t
)
⟹
x
(
t
−
Δ
t
)
=
x
(
t
)
+
β
(
t
)
[
1
2
x
(
t
)
+
∇
x
log
p
t
(
x
)
]
Δ
t
+
β
(
t
)
Δ
t
z
(
t
)
x(t-\Delta t)-x(t)=\beta(t)\left[\frac{1}{2}x(t)+\nabla_x\log p_t(x)\right]\mathbb{\Delta t}+\sqrt{\beta(t)\Delta t}z(t)\\\Longrightarrow x(t-\Delta t) =x(t)+\beta(t)\left[\frac{1}{2}x(t)+\nabla_x\log p_t(x)\right]\mathbb{\Delta t}+\sqrt{\beta(t)\Delta t}z(t)\nonumber
x(t−Δt)−x(t)=β(t)[21x(t)+∇xlogpt(x)]Δt+β(t)Δtz(t)⟹x(t−Δt)=x(t)+β(t)[21x(t)+∇xlogpt(x)]Δt+β(t)Δtz(t)
进一步把等式右侧的
x
(
t
)
x(t)
x(t)结合
x
(
t
−
Δ
t
)
=
[
1
+
1
2
β
(
t
)
Δ
t
]
x
(
t
)
+
β
(
t
)
Δ
t
∇
x
log
p
t
(
x
)
+
β
(
t
)
Δ
t
z
(
t
)
≈
[
1
+
1
2
β
(
t
)
Δ
t
]
x
(
t
)
+
β
(
t
)
Δ
t
∇
x
log
p
t
(
x
)
+
(
β
(
t
)
Δ
t
)
2
2
∇
x
log
p
t
(
x
)
+
β
(
t
)
Δ
t
z
(
t
)
=
[
1
+
1
2
β
(
t
)
Δ
t
]
x
(
t
)
+
[
1
+
β
(
t
)
Δ
t
2
]
β
(
t
)
Δ
t
∇
x
log
p
t
(
x
)
+
β
(
t
)
Δ
t
z
(
t
)
=
[
1
+
β
(
t
)
Δ
t
2
]
(
x
(
t
)
+
β
(
t
)
Δ
t
∇
x
log
P
t
(
x
)
)
+
β
(
t
)
Δ
t
z
(
t
)
≈
[
1
1
−
β
(
t
)
Δ
t
]
(
x
(
t
)
+
β
(
t
)
Δ
t
∇
x
log
P
t
(
x
)
)
+
β
(
t
)
Δ
t
z
(
t
)
\begin{aligned}x(t-\Delta t)=&\left[1+\frac{1}{2}\beta(t)\Delta t\right]x(t)+\beta(t)\Delta t\nabla_x\log p_t(x)+\sqrt{\beta(t)\Delta t}z(t)\\\approx&\left[1+\frac{1}{2}\beta(t)\Delta t\right]x(t)+\beta(t)\Delta t\nabla_x\log p_t(x)+\frac{(\beta(t)\Delta t)^2}{2}\nabla_x\log p_t(x)+\sqrt{\beta(t)\Delta t}z(t)\\=&\left[1+\frac{1}{2}\beta(t)\Delta t\right]x(t)+\left[1+\frac{\beta(t)\Delta t}{2}\right]\beta(t)\Delta t\nabla_x\log p_t(x)+\sqrt{\beta(t)\Delta t}z(t)\\=&\left[1+\frac{\beta(t)\Delta t}{2}\right]\left(x(t)+\beta(t)\Delta t \nabla_x\log P_t(x)\right)+\sqrt{\beta(t)\Delta t}z(t)\\\approx&\left[\frac{1}{\sqrt{1-\beta(t)\Delta t}}\right]\left(x(t)+\beta(t)\Delta t \nabla_x\log P_t(x)\right)+\sqrt{\beta(t)\Delta t}z(t)\end{aligned}\nonumber
x(t−Δt)=≈==≈[1+21β(t)Δt]x(t)+β(t)Δt∇xlogpt(x)+β(t)Δtz(t)[1+21β(t)Δt]x(t)+β(t)Δt∇xlogpt(x)+2(β(t)Δt)2∇xlogpt(x)+β(t)Δtz(t)[1+21β(t)Δt]x(t)+[1+2β(t)Δt]β(t)Δt∇xlogpt(x)+β(t)Δtz(t)[1+2β(t)Δt](x(t)+β(t)Δt∇xlogPt(x))+β(t)Δtz(t)[1−β(t)Δt1](x(t)+β(t)Δt∇xlogPt(x))+β(t)Δtz(t)
第一个约等号是因为
β
(
t
)
Δ
t
≪
1
\beta(t)\Delta t \ll 1
β(t)Δt≪1。第二个约等于是因为泰勒展开
令 x ( t − Δ t ) = x i − 1 , x ( t ) = x i , β ( t ) = β ˉ i , Δ t = 1 N , z ( t ) = z i x(t-\Delta t)=x_{i-1},x(t)=x_i,\beta(t)=\bar \beta_i,\Delta t=\frac{1}{N},z(t)=z_i x(t−Δt)=xi−1,x(t)=xi,β(t)=βˉi,Δt=N1,z(t)=zi,并由正向过程曾定义过 β i = 1 N β ˉ i = Δ t β ( t ) \beta_i=\frac{1}{N}\bar\beta_i=\Delta t\beta(t) βi=N1βˉi=Δtβ(t),又有 ∇ x log P t ( x ) ≈ s θ \nabla_x\log P_t(x)\approx s_\theta ∇xlogPt(x)≈sθ
将其代入可得
x
i
−
1
=
1
1
−
β
i
(
x
i
+
β
i
s
θ
)
+
β
i
z
i
x_{i-1}=\frac{1}{\sqrt{1-\beta_i}}\left(x_i+\beta_is_\theta\right)+\sqrt{\beta_i}z_i
xi−1=1−βi1(xi+βisθ)+βizi
可以看到这就是DDPM的采样方式(祖先采样)
接下来我们来看NCSN的
对于NCSN我们所要推导出来的并不是郎之万动力采样。而是NCSN的祖先采样(
q
(
x
i
−
1
∣
x
i
,
x
0
)
q(x_{i-1}|x_i,x_0)
q(xi−1∣xi,x0)),依据马尔可夫性质,它是完全可以算出来的(计算方法与DDPM一样),以下我直接给出结论(不懂的可以看一下DDPM那篇论文,或者看论文)
q
(
x
i
−
1
∣
x
i
,
x
0
)
=
N
(
x
i
−
1
;
σ
i
−
1
2
σ
i
2
x
i
+
(
1
−
σ
i
−
1
2
σ
i
2
)
x
0
,
σ
i
−
1
2
(
σ
i
2
−
σ
i
−
1
2
)
σ
i
2
I
)
q(x_{i-1}|x_i,x_0)=\mathcal{N}\left(x_{i-1};\frac{\sigma^2_{i-1}}{\sigma^2_{i}}x_i+(1-\frac{\sigma_{i-1}^2}{\sigma^2_{i}})x_0,\frac{\sigma^2_{i-1}(\sigma^2_{i}-\sigma^2_{i-1})}{\sigma^2_{i}}I\right)
q(xi−1∣xi,x0)=N(xi−1;σi2σi−12xi+(1−σi2σi−12)x0,σi2σi−12(σi2−σi−12)I)
同样与DDPM那边,把里面的
x
0
x_0
x0用加噪过程
x
i
=
x
0
+
σ
i
z
x_i=x_0+\sigma_iz
xi=x0+σiz代入即可把期望变成
μ
=
x
i
(
x
0
,
z
)
+
(
σ
i
2
−
σ
i
−
1
2
)
s
\mu=x_i(x_0,z)+(\sigma^2_{i}-\sigma^2_{i-1})s
μ=xi(x0,z)+(σi2−σi−12)s
其中
s
=
−
z
σ
i
s=-\frac{z}{\sigma_i}
s=−σiz
所以采样方法就是
x
i
−
1
=
x
i
+
(
σ
i
2
−
σ
i
−
1
2
)
s
θ
(
x
i
,
i
)
+
σ
i
−
1
2
(
σ
i
2
−
σ
i
−
1
2
)
σ
i
2
z
x_{i-1}=x_i+(\sigma^2_{i}-\sigma^2_{i-1})s_\theta(x_i,i)+\sqrt{\frac{\sigma^2_{i-1}(\sigma^2_{i}-\sigma^2_{i-1})}{\sigma^2_{i}}}z
xi−1=xi+(σi2−σi−12)sθ(xi,i)+σi2σi−12(σi2−σi−12)z
我们在DDPM那里说过,对于方差的选择,其实会选择成加噪过程的方差,所以有
x
i
−
1
=
x
i
+
(
σ
i
2
−
σ
i
−
1
2
)
s
θ
(
x
i
,
i
)
+
(
σ
i
2
−
σ
i
−
1
2
)
z
x_{i-1}=x_i+(\sigma^2_{i}-\sigma^2_{i-1})s_\theta(x_i,i)+\sqrt{(\sigma^2_{i}-\sigma^2_{i-1})}z
xi−1=xi+(σi2−σi−12)sθ(xi,i)+(σi2−σi−12)z
②NCSN:
f
(
x
,
t
)
=
0
,
g
(
t
)
=
d
[
σ
2
(
t
)
]
d
t
\mathbb{f(x,t)=0},g(t)=\sqrt{\frac{d\left[\sigma^2(t)\right]}{dt}}
f(x,t)=0,g(t)=dtd[σ2(t)]
d
x
=
[
f
(
x
,
t
)
−
d
[
σ
2
(
t
)
]
d
t
∇
x
log
p
t
(
x
)
]
d
t
+
g
(
t
)
d
w
ˉ
=
−
d
[
σ
2
(
t
)
]
d
t
∇
x
log
p
t
(
x
)
d
t
+
d
[
σ
2
(
t
)
]
d
t
d
w
ˉ
\begin{aligned}\mathbb{dx}=&\left[\mathbb{f(x,t)}-\sqrt{\frac{d\left[\sigma^2(t)\right]}{dt}}\nabla_x\log p_t(x)\right]\mathbb{dt}+g(t)\mathbb{d\bar w}\\=&-\frac{d\left[\sigma^2(t)\right]}{dt}\nabla_x\log p_t(x)\mathbb{dt}+\sqrt{\frac{d\left[\sigma^2(t)\right]}{dt}}\mathbb{d\bar w}\end{aligned}\nonumber
dx==[f(x,t)−dtd[σ2(t)]∇xlogpt(x)]dt+g(t)dwˉ−dtd[σ2(t)]∇xlogpt(x)dt+dtd[σ2(t)]dwˉ
即
x
(
t
−
Δ
t
)
=
x
(
t
)
+
σ
2
(
t
)
−
σ
2
(
t
−
Δ
t
)
Δ
t
∇
x
log
p
t
(
x
)
Δ
t
+
σ
2
(
t
)
−
σ
2
(
t
−
Δ
t
)
Δ
t
Δ
t
z
(
t
)
=
x
(
t
)
+
(
σ
2
(
t
)
−
σ
2
(
t
−
Δ
t
)
)
∇
x
log
p
t
(
x
)
+
σ
2
(
t
)
−
σ
2
(
t
−
Δ
t
)
z
(
t
)
\begin{aligned}x(t-\Delta t)=&x(t)+\frac{\sigma^2(t)-\sigma^2(t-\Delta t)}{\Delta t}\nabla_x\log p_t(x)\mathbb{\Delta t}+\sqrt{\frac{\sigma^2(t)-\sigma^2(t-\Delta t)}{\Delta t}\Delta t}\mathbb{z(t)}\\=&x(t)+\left(\sigma^2(t)-\sigma^2(t-\Delta t)\right)\nabla_x\log p_t(x)+\sqrt{\sigma^2(t)-\sigma^2(t-\Delta t)}z(t)\end{aligned}\nonumber
x(t−Δt)==x(t)+Δtσ2(t)−σ2(t−Δt)∇xlogpt(x)Δt+Δtσ2(t)−σ2(t−Δt)Δtz(t)x(t)+(σ2(t)−σ2(t−Δt))∇xlogpt(x)+σ2(t)−σ2(t−Δt)z(t)
如同DDPM那般,便可得到NCSN的采样方法
x
i
−
1
=
x
i
+
(
σ
i
2
−
σ
i
−
1
2
)
s
θ
+
σ
i
2
−
σ
i
−
1
2
z
i
(9)
x_{i-1}=x_i+(\sigma^2_i-\sigma^2_{i-1})s_\theta+\sqrt{\sigma^2_i-\sigma^2_{i-1}}z_i\tag{9}
xi−1=xi+(σi2−σi−12)sθ+σi2−σi−12zi(9)
由此,证明完毕,与NCSN里面的方式是一样的。
5、SDE数值求解器
5.1、反向扩散采样器
这一节针对采样问题
反向过程是一个SDE,那么问题就来了,我们肯定需要采样,如果采样就必须设定一定的采样步骤(肯定不能是连续,不离散化我们怎么采样)。那么该如何离散化呢?这就是要涉及数值求解器,不同的离散化方案就对应不同的数值求解器(采样方法)
我们曾证明DDPM和NCSN的采样方法(祖先采样)可以连续化成SDE。那么很显然DDPM和NCSN里面的祖先采样就是一种离散化方案,对应一个采样过程
论文假设前向加噪过程已经进行了离散化预设,即加噪时刻为 i ∈ { 0 , 1 , ⋯ , N − 1 } i\in\{0,1,\cdots,N-1\} i∈{0,1,⋯,N−1}
DDPM:
由
d
x
=
[
f
(
x
,
t
)
−
g
(
t
)
2
∇
x
log
p
t
(
x
)
]
d
t
+
g
(
t
)
d
w
ˉ
(10)
\mathbb{dx}=\left[\mathbb{f(x,t)}-g(t)^2\nabla_x\log p_t(x)\right]\mathbb{dt}+g(t)\mathbb{d\bar w}\tag{10}
dx=[f(x,t)−g(t)2∇xlogpt(x)]dt+g(t)dwˉ(10)
时间步
d
t
=
−
Δ
t
dt=-\Delta t
dt=−Δt,由我们之前所说
Δ
t
β
(
t
)
=
Δ
t
β
ˉ
i
=
β
i
\Delta t\beta(t)=\Delta t\bar\beta_i=\beta_i
Δtβ(t)=Δtβˉi=βi,
f
(
x
,
t
)
=
−
1
2
β
(
t
)
x
(
t
)
\mathbb{f(x,t)}=-\frac{1}{2}\beta(t)x(t)
f(x,t)=−21β(t)x(t),
g
(
t
)
=
β
(
t
)
g(t)=\sqrt{\beta(t)}
g(t)=β(t),所以我们可以得到
x
(
t
−
Δ
t
)
−
x
(
t
)
=
1
2
β
(
t
)
Δ
t
x
(
t
)
+
β
(
t
)
Δ
t
∇
x
log
p
t
(
x
)
+
β
(
t
)
Δ
t
z
(
t
)
x(t-\Delta t)-x(t)=\frac{1}{2}\beta(t)\Delta tx(t)+\beta(t)\Delta t\nabla_x\log p_t(x)+\sqrt{\beta(t)\Delta t}z(t)
x(t−Δt)−x(t)=21β(t)Δtx(t)+β(t)Δt∇xlogpt(x)+β(t)Δtz(t)
把
x
(
t
)
x(t)
x(t)移动到等式右边,并写成下标的形式,对数梯度写成分数s
x
i
−
1
=
(
1
+
1
2
β
i
)
x
i
+
β
i
s
θ
(
x
i
,
i
)
+
β
i
z
i
=
[
2
−
(
1
−
1
2
β
i
)
]
x
i
+
β
i
s
θ
(
x
i
,
i
)
+
β
i
z
i
≈
(
2
−
1
−
β
i
)
x
i
+
β
i
s
θ
(
x
i
,
i
)
+
β
i
z
i
(11)
\begin{aligned}x_{i-1}=&(1+\frac{1}{2}\beta_i)x_{i}+\beta_is_\theta(x_i,i)+\sqrt{\beta_i}z_i\\=&\left[2-(1-\frac{1}{2}\beta_i)\right]x_{i}+\beta_is_\theta(x_i,i)+\sqrt{\beta_i}z_i\\\approx&\left(2-\sqrt{1-\beta_i}\right)x_i+\beta_is_\theta(x_i,i)+\sqrt{\beta_i}z_i\end{aligned}\tag{11}
xi−1==≈(1+21βi)xi+βisθ(xi,i)+βizi[2−(1−21βi)]xi+βisθ(xi,i)+βizi(2−1−βi)xi+βisθ(xi,i)+βizi(11)
最后一个约等号,是泰勒展开(
β
i
→
0
\beta_i\to 0
βi→0)。由此我们得到了这个采样步骤
而对NCSN,也是一样的道理
NCSN:
f
(
x
,
t
)
=
0
,
g
(
t
)
=
d
[
σ
2
(
t
)
]
d
t
\mathbb{f(x,t)=0},g(t)=\sqrt{\frac{d\left[\sigma^2(t)\right]}{dt}}
f(x,t)=0,g(t)=dtd[σ2(t)]
x
(
t
−
Δ
t
)
−
x
(
t
)
=
(
σ
i
2
−
σ
i
−
1
2
)
Δ
t
Δ
t
∇
x
log
p
t
(
x
)
+
(
σ
i
2
−
σ
i
−
1
2
)
Δ
t
Δ
t
z
(
t
)
\begin{aligned}x(t-\Delta t)-x(t)=&\frac{(\sigma_i^2-\sigma_{i-1}^2)}{\Delta t}\Delta t\nabla_x\log p_t(x)+\sqrt{\frac{(\sigma_i^2-\sigma_{i-1}^2)}{\Delta t}\Delta t}z(t)\end{aligned}\nonumber
x(t−Δt)−x(t)=Δt(σi2−σi−12)Δt∇xlogpt(x)+Δt(σi2−σi−12)Δtz(t)
整理后可得
x
i
−
1
=
x
i
+
(
σ
i
2
−
σ
i
−
1
2
)
s
θ
(
x
i
,
i
)
+
σ
i
2
−
σ
i
−
1
2
z
i
(12)
x_{i-1}=x_i+(\sigma_i^2-\sigma_{i-1}^2)s_\theta(x_i,i)+\sqrt{\sigma_i^2-\sigma_{i-1}^2}z_i\tag{12}
xi−1=xi+(σi2−σi−12)sθ(xi,i)+σi2−σi−12zi(12)
事实上,如果我们有一个离散化的方案,可以把Eq.(10)变成如下形式(把离散的时间写入f,g的下标)
x
i
−
1
=
x
i
−
f
i
(
x
i
)
+
g
i
2
s
θ
(
s
i
,
i
)
+
g
i
z
i
(13)
x_{i-1}=x_{i}-\mathbb{f}_{i}(x_i)+g_i^2s_\theta(s_i,i)+g_iz_i\tag{13}
xi−1=xi−fi(xi)+gi2sθ(si,i)+gizi(13)
基于Eq.(13)而形成的采样方法(如Eq.(11)、Eq(12)),论文统称为反向扩散采样器
5.2、其他离散化方案
不同的离散化方案,造就了不同的采样方式。作者列举了目前比较出名的采样方式
比如欧拉-丸山法(Euler-Maruyama)和龙格-库塔法( Runge-Kutta)。感兴趣的自行百度即可。
5.3、预测-校正采样器
不同的离散化方案总是存在一定的误差,因为时间本身是连续的。为了减小这种误差,我们总是寻找较好的离散化方案。但再好的方案也都是会存在些许误差。为了消去这些误差,论文提出预测-矫正采样器
具体来说,我们都知道,使用SDE数值求解器,可以得到某个时间的离散化样本估计(此为预测),由于是离散化的,他总是存在一些误差。如何消去这些误差呢?当燃是使用MCMC的方法(我在之前可能没有提到过,其实就是马尔可夫蒙特卡洛,比如之前的郎之万动力采样就是其中的一个类型,该方法可以让点走到概率最高点附近),故而把该步骤称为校正
蓝色部分为预测(先预测一个离散时刻的点),黄色部分为校正(对预测其的输出进行MCMC校正,使其走到概率最高点附近)
6、概率流ODE
对每个反向SDE,都存在一个对应的ODE(常微分),其实就是把SDE的随机项去掉。把随机项去掉之后,就变成了一个确定的过程。论文对这个确定的ODE进行了推导,得到了下面ODE(感兴趣的自己看吧,字数太多又发不了了)
d
x
=
[
f
(
x
,
t
)
−
1
2
g
(
t
)
2
∇
x
log
p
t
(
x
)
]
d
t
\mathbb{dx=}\left[\mathbb{f(x,t)}-\frac{1}{2}g(t)^2\nabla_x\log p_t(x)\right]\mathbb{dt}
dx=[f(x,t)−21g(t)2∇xlogpt(x)]dt
我们可以得到和SDE那样的离散化采样器,求法都是一样的(把对应的f,g的值代进去),所以可以得到
DDPM:
x
i
−
1
=
(
2
−
1
−
β
i
)
x
i
+
1
2
β
i
s
θ
(
x
i
,
i
)
x_{i-1}=(2-\sqrt{1-\beta_i})x_i+\frac{1}{2}\beta_{i}s_\theta(x_i,i)
xi−1=(2−1−βi)xi+21βisθ(xi,i)
NCSN:
x
i
−
1
=
x
i
+
1
2
(
σ
i
2
−
σ
i
−
1
2
)
s
θ
(
x
i
,
i
)
x_{i-1}=x_{i}+\frac{1}{2}(\sigma_i^2-\sigma_{i-1}^2)s_\theta(x_i,i)
xi−1=xi+21(σi2−σi−12)sθ(xi,i)
除此之外,通过其他离散化方案,使用黑盒ODE求解器,能够让采样过程变得高效。
当燃了,虽然高效,但是在没有校正器的情况下,他们的FID(一个衡量指标)往往不如SDE的求解器
7、其他内容
①除此之外,我们在之前的正向SDE中,曾给出过公式,里面的
g
(
t
)
g(t)
g(t)只跟时间有关,其实更一般的,他还跟x有关,所以
d
x
=
f
(
x
,
t
)
d
t
+
G
(
x
,
t
)
d
w
\mathbb{dx=f(x,t)dt}+G(x,t)\mathbb{dw}
dx=f(x,t)dt+G(x,t)dw
其中
G
(
⋅
,
t
)
:
R
d
→
R
d
×
d
G(\cdot,t):\mathbb{R}^d\to \mathbb{R}^{d\times d}
G(⋅,t):Rd→Rd×d
同样的反向过程可以表示为
d
x
=
[
f
(
x
,
t
)
−
∇
G
(
x
,
t
)
G
(
x
,
t
)
T
]
−
G
(
x
,
t
)
G
(
x
,
t
)
T
∇
x
log
p
t
(
x
)
d
t
+
G
(
x
,
t
)
d
w
ˉ
\mathbb{dx}=\left[f(x,t)-\nabla G(x,t)G(x,t)^T\right]-G(x,t)G(x,t)^T\nabla_x\log p_t(x)\mathbb{dt}+G(x,t)\mathbb{d\bar w}
dx=[f(x,t)−∇G(x,t)G(x,t)T]−G(x,t)G(x,t)T∇xlogpt(x)dt+G(x,t)dwˉ
②论文里面还提到的条件生成的SDE,感兴趣的自己去看看
③论文还提到的似然的计算,感兴趣的自己去看看
④论文还提到了很多的其他的小细节,感兴趣的自己去看看
8、DDIM附加
这篇论文并没有讲到DDIM,但是DDIM提到了这篇论文。DDIM里面说到他们的方法ODE化之后等价于VE SDE的ODE形式(但采样方法不同),里面有所证明,感兴趣的可以看看Denoising Diffusion Implicit Models (arxiv.org)
9、结束
好了,本篇文章到此为止,如有误entire,还望指出,阿里嘎多!
10、参考
Score-based SDE 扩散生成模型从入门到出师系列(二):揭秘随机微分方程如何应用于采样生成 - 知乎 (zhihu.com)