Score-Based Generative Modeling Through Stochastic Differential Equations
Yang Song, Stanford University, ICLR2021, Cited:723, Code: 无, Paper.
目录子
1. 前言
这篇文章是关于一种新的生成建模方法,它通过随机微分方程(SDE)将复杂的数据分布平滑地转换为已知的先验分布。通过逆向时间SDE,可以将先验分布转换回数据分布。这种方法依赖于扰动数据分布的时间相关梯度场(即得分),可以使用神经网络准确估计这些得分,并使用数值SDE求解器生成样本。该框架包含了以前在基于得分的生成建模和扩散概率建模中的方法,并允许新的采样过程和新的建模能力。
2. 方法
2.1 SDE推导
随机微分方程SDE是一种微分方程,其中一个或多个项是随机过程,其解本身也是一个随机过程。SDE用于对股价和利率波动等现象进行建模。推导过程如下,假设我们有一个确定性的微分方程:
d
x
=
f
(
x
)
d
t
\begin{equation} dx = f(x)dt \end{equation}
dx=f(x)dt
其中,
x
x
x是一个函数,
f
f
f是一个已知的函数。我们可以用欧拉法来离散化这个方程,得到:
x
t
+
Δ
t
−
x
t
=
f
(
x
t
)
Δ
t
\begin{equation} x_{t+\Delta t} - x_t = f(x_t)\Delta t \end{equation}
xt+Δt−xt=f(xt)Δt
其中,
Δ
t
\Delta t
Δt是一个小的时间间隔。如果我们想要在这个方程中加入一些随机性,比如说由于测量误差或者外部干扰等原因,我们可以在右边加上一个噪声项:
x
t
+
Δ
t
−
x
t
=
f
(
x
t
)
Δ
t
+
g
(
x
t
)
Δ
t
ϵ
\begin{equation} x_{t+\Delta t} - x_t = f(x_t)\Delta t + g(x_t)\sqrt{\Delta t}\epsilon \end{equation}
xt+Δt−xt=f(xt)Δt+g(xt)Δtϵ
其中,
g
g
g是一个已知的函数,
ϵ
\epsilon
ϵ是一个服从标准正态分布的随机变量。这样,我们就得到了一个离散形式的SDE。如果我们让
Δ
t
→
0
\Delta t \to 0
Δt→0,那么我们就可以得到连续形式的SDE的一般形式:
d
x
=
f
(
x
)
d
t
+
g
(
x
)
d
W
t
\begin{equation} dx = f(x)dt + g(x)dW_t \end{equation}
dx=f(x)dt+g(x)dWt
本身扩散模型就属于一个随机过程,那么用SDE描述扩散过程便是一个自然的事。
2.2 基于SDE的扩散过程
d
x
=
f
(
x
,
t
)
d
t
+
g
(
t
)
d
w
\begin{equation} dx = f(x,t)dt + g(t)dw \end{equation}
dx=f(x,t)dt+g(t)dw
其中,
f
(
x
,
t
)
f(x,t)
f(x,t)是漂移项,表示确定性的变化;
g
(
t
)
d
w
g(t)dw
g(t)dw是扩散项,表示随机性的变化。
d
w
dw
dw表示维纳过程(也称为布朗运动)的增量。飘逸项是从前一个状态到下一个状态的变化量。扩散项中的
d
w
dw
dw指的是加入的噪声,
g
(
t
)
g(t)
g(t)是加噪的强度。
2.3 基于SDE的扩散重建
一下推导来自B站台UP主VictorYuki,将公式5写为离散形式
x
t
+
Δ
t
=
x
t
+
f
(
x
t
,
t
)
Δ
t
+
g
(
t
)
Δ
t
ϵ
\begin{equation} x_{t+\Delta t} = x_t + f(x_t,t)\Delta t + g(t)\sqrt{\Delta t }\epsilon \end{equation}
xt+Δt=xt+f(xt,t)Δt+g(t)Δtϵ
这里
ϵ
∼
N
(
0
,
1
)
\epsilon \sim N(0,1)
ϵ∼N(0,1)。
x
t
+
Δ
t
x_{t+\Delta t}
xt+Δt在给定
x
t
x_{t}
xt时的概率分布,当
ϵ
\epsilon
ϵ是一个标准正态随机变量,在这种情况下,
x
t
+
Δ
t
x_{t+\Delta t}
xt+Δt服从一个正态分布,均值为
x
t
+
f
(
x
t
,
t
)
Δ
t
x_t + f(x_t,t)\Delta t
xt+f(xt,t)Δt,方差为
g
2
(
t
)
Δ
t
I
g^{2}(t)\Delta tI
g2(t)ΔtI。
p
(
x
t
+
Δ
t
∣
x
t
)
∼
N
(
f
(
x
t
,
t
)
Δ
t
,
g
2
(
t
)
Δ
t
I
)
\begin{equation} p(x_{t+\Delta t}|x_{t}) \sim N(f(x_t,t)\Delta t , g^{2}(t)\sqrt{\Delta t }I) \end{equation}
p(xt+Δt∣xt)∼N(f(xt,t)Δt,g2(t)ΔtI)
要从
t
+
Δ
t
t+\Delta t
t+Δt得到
x
t
x_{t}
xt,需要求解:
p
(
x
t
∣
x
t
+
Δ
t
)
=
p
(
x
t
+
Δ
t
∣
x
t
)
p
(
x
t
)
/
p
(
x
t
+
Δ
t
)
=
p
(
x
t
+
Δ
t
∣
x
t
)
e
x
p
{
l
o
g
p
(
x
t
)
−
l
o
g
p
(
x
t
+
Δ
t
)
}
\begin{align} p(x_{t}|x_{t+\Delta t})&=p(x_{t+\Delta t}|x_{t})p(x_{t})/p(x_{t+\Delta t})\\ &= p(x_{t+\Delta t}|x_{t})exp\{logp(x_{t})-logp(x_{t+\Delta t})\} \end{align}
p(xt∣xt+Δt)=p(xt+Δt∣xt)p(xt)/p(xt+Δt)=p(xt+Δt∣xt)exp{logp(xt)−logp(xt+Δt)}
先将
l
o
g
p
(
x
t
+
Δ
t
)
logp(x_{t+\Delta t})
logp(xt+Δt)泰勒一阶展开得
l
o
g
p
(
x
t
)
+
(
x
t
+
Δ
t
−
x
t
)
∇
x
t
l
o
g
p
(
x
t
)
+
Δ
t
∂
∂
t
l
o
g
p
(
x
t
)
logp(x_{t})+(x_{t+\Delta t}-x_{t}) \nabla _{x_{t}}logp(x_{t})+\Delta t \frac{\partial }{\partial t}logp(x_{t})
logp(xt)+(xt+Δt−xt)∇xtlogp(xt)+Δt∂t∂logp(xt),带入公式9并将其第一项展开为高斯分布形式得:
E
q
.
(
9
)
∝
e
x
p
{
−
∥
x
t
+
Δ
t
−
x
t
−
f
(
x
t
,
t
)
Δ
t
∥
2
2
2
g
2
(
t
)
Δ
t
−
(
x
t
+
Δ
t
−
x
t
)
∇
x
t
l
o
g
p
(
x
t
)
−
Δ
t
∂
∂
t
l
o
g
p
(
x
t
)
}
=
e
x
p
{
−
∥
(
x
t
+
Δ
t
−
x
t
)
−
(
f
(
x
t
,
t
)
−
g
2
(
t
)
∇
x
t
l
o
g
p
(
x
t
)
)
Δ
t
∥
2
2
2
g
2
(
t
)
Δ
t
−
Δ
t
∂
∂
t
l
o
g
p
(
x
t
)
−
f
2
(
x
t
,
t
)
Δ
t
g
2
(
t
)
+
(
f
(
x
t
,
t
)
−
g
2
(
t
)
∇
x
t
l
o
g
p
(
x
t
)
)
2
Δ
t
2
g
2
(
t
)
}
{\small \begin{align} Eq. (9) &\propto exp\left \{ -\frac{\left \| x_{t+\Delta t}-x_{t}-f(x_{t},t)\Delta t \right \|_{2}^{2} }{2g^{2}(t)\Delta t} - (x_{t+\Delta t}-x_{t})\nabla_{x_{t}} logp(x_{t}) - \Delta t \frac{\partial }{\partial t}logp(x_{t}) \right \} \\ &=exp\left \{ -\frac{\left \| (x_{t+\Delta t}-x_{t}) -(f(x_{t},t)- g^{2}(t)\nabla_{x_{t}} logp(x_{t}))\Delta t \right \|_{2}^{2} }{2g^{2}(t)\Delta t} -\Delta t \frac{\partial }{\partial t}logp(x_{t}) - \frac{f^{2}(x_{t},t)\Delta t}{g^{2}(t)} + \frac{(f(x_{t},t)-g^{2}(t)\nabla_{x_{t}} logp(x_{t}))^{2}\Delta t}{2g^{2}(t)} \right \} \end{align}}
Eq.(9)∝exp{−2g2(t)Δt∥xt+Δt−xt−f(xt,t)Δt∥22−(xt+Δt−xt)∇xtlogp(xt)−Δt∂t∂logp(xt)}=exp{−2g2(t)Δt∥(xt+Δt−xt)−(f(xt,t)−g2(t)∇xtlogp(xt))Δt∥22−Δt∂t∂logp(xt)−g2(t)f2(xt,t)Δt+2g2(t)(f(xt,t)−g2(t)∇xtlogp(xt))2Δt}
因为我们是从
x
t
+
Δ
t
→
x
t
x_{t+\Delta t}\to x_{t}
xt+Δt→xt,因此我们得到的分布应为都是关于时间
t
+
Δ
t
t+\Delta t
t+Δt的,把(11)中的
t
t
t都改写,并当
Δ
t
→
0
\Delta t \to 0
Δt→0时,后面三项:
E
q
.
(
11
)
≈
e
x
p
{
−
∥
(
x
t
+
Δ
t
−
x
t
)
−
(
f
(
x
t
+
Δ
t
,
t
+
Δ
t
)
−
g
2
(
t
+
Δ
t
)
∇
x
t
+
Δ
t
l
o
g
p
(
x
t
+
Δ
t
)
)
Δ
t
∥
2
2
2
g
2
(
t
+
Δ
t
)
Δ
t
}
{\small \begin{equation} Eq. (11) \approx exp\left \{ -\frac{\left \| (x_{t+\Delta t}-x_{t}) -(f(x_{t+\Delta t},t+\Delta t)- g^{2}(t+\Delta t)\nabla_{x_{t+\Delta t}} logp(x_{t+\Delta t}))\Delta t \right \|_{2}^{2} }{2g^{2}(t+\Delta t)\Delta t}\right \} \end{equation}}
Eq.(11)≈exp{−2g2(t+Δt)Δt∥(xt+Δt−xt)−(f(xt+Δt,t+Δt)−g2(t+Δt)∇xt+Δtlogp(xt+Δt))Δt∥22}
可以得到
p
(
x
t
∣
x
t
+
Δ
t
)
p(x_{t}|x_{t+\Delta t})
p(xt∣xt+Δt)的均值为
(
x
t
+
Δ
t
−
x
t
)
−
(
f
(
x
t
+
Δ
t
,
t
+
Δ
t
)
−
g
2
(
t
+
Δ
t
)
∇
x
t
+
Δ
t
l
o
g
p
(
x
t
+
Δ
t
)
(x_{t+\Delta t}-x_{t}) -(f(x_{t+\Delta t},t+\Delta t)- g^{2}(t+\Delta t)\nabla_{x_{t+\Delta t}} logp(x_{t+\Delta t})
(xt+Δt−xt)−(f(xt+Δt,t+Δt)−g2(t+Δt)∇xt+Δtlogp(xt+Δt),方差为
2
g
2
(
t
+
Δ
t
)
Δ
t
2g^{2}(t+\Delta t)\Delta t
2g2(t+Δt)Δt。则采样的连续公式为:
d
x
=
[
f
(
x
,
t
)
−
g
2
(
t
)
∇
x
t
l
o
g
p
(
x
t
)
]
d
t
+
g
(
t
)
d
w
\begin{equation} dx=[f(x,t)-g^{2}(t)\nabla_{x_{t}}logp(x_{t})]dt+g(t)dw \end{equation}
dx=[f(x,t)−g2(t)∇xtlogp(xt)]dt+g(t)dw
离散的过程表示为:
x
t
−
1
=
x
t
−
[
f
(
x
t
,
t
)
−
g
2
(
t
)
∇
x
t
l
o
g
p
(
x
t
)
+
g
(
t
)
ϵ
\begin{equation} x_{t-1}=x_{t}-[f(x_{t},t)-g^{2}(t)\nabla_{x_{t}}logp(x_{t})+g(t) \epsilon \end{equation}
xt−1=xt−[f(xt,t)−g2(t)∇xtlogp(xt)+g(t)ϵ
2.4 Variance Exploring(VE)and Variation Preserving(VP)
SDE在理论上统一了Score-based Model (NCSN)和DDPM,他们分别对应VE and VP。VE扩散过程在扩散过程的每个步骤增加了更多的噪声,导致潜在变量的分布更广。这有助于探索不同的数据分布模式。VP扩散过程在扩散过程的每个步骤中保持方差恒定,从而导致潜在变量的球形高斯分布。这有助于保存信息并避免模式崩溃。这两种方法都旨在提高扩散模型采样的效率和准确性。然而,VE比VP具有一些优势,例如更好的似然估计、更快的收敛和更低的内存消耗。
方法 | Variance Exploring | Variation Preserving |
---|---|---|
x t = x_{t}= xt= | x 0 + σ t ϵ x_{0}+\sigma_{t}\epsilon x0+σtϵ | α t ˉ x 0 + 1 − α t ˉ ϵ \sqrt{\bar{\alpha_{t}}}x_{0}+\sqrt{1-\bar{\alpha_{t}}}\epsilon αtˉx0+1−αtˉϵ |
x t + 1 = x_{t+1}= xt+1= | x t + σ t + 1 2 − σ t 2 ϵ x_{t}+\sqrt{\sigma_{t+1}^{2}-\sigma_{t}^{2}}\epsilon xt+σt+12−σt2ϵ | 1 − β t + 1 x t + β t + 1 ϵ \sqrt{1-\beta_{t+1}}x_{t}+\sqrt{\beta_{t+1}}\epsilon 1−βt+1xt+βt+1ϵ |
我们现在要求公式6中的
f
(
x
t
,
t
)
f(x_{t},t)
f(xt,t)和
g
(
t
)
g(t)
g(t),将公式17和公式21分别与公式6比对,就可求出:
VE:
f
(
x
t
,
t
)
=
0
f(x_{t},t)=0
f(xt,t)=0和
g
(
t
)
=
d
d
t
σ
t
2
g(t)=\frac{d}{dt}\sigma_{t}^{2}
g(t)=dtdσt2。
x
t
+
Δ
t
=
x
t
+
σ
t
+
Δ
t
2
−
σ
t
2
ϵ
=
x
t
+
(
σ
t
+
Δ
t
2
−
σ
t
2
)
/
Δ
t
Δ
t
ϵ
=
x
t
+
Δ
σ
t
2
/
Δ
t
Δ
t
ϵ
\begin{align} x_{t+\Delta t}&=x_{t}+\sqrt{\sigma_{t+\Delta t}^{2}-\sigma_{t}^{2}}\epsilon\\ &= x_{t} + \sqrt{(\sigma_{t+\Delta t}^{2}-\sigma_{t}^{2})/\sqrt{\Delta t}} \sqrt{\Delta t}\epsilon\\ &=x_{t} + \sqrt{\Delta \sigma_{t}^{2}/\sqrt{\Delta t}} \sqrt{\Delta t}\epsilon \end{align}
xt+Δt=xt+σt+Δt2−σt2ϵ=xt+(σt+Δt2−σt2)/ΔtΔtϵ=xt+Δσt2/ΔtΔtϵ
VP:
f
(
x
t
,
t
)
=
−
1
2
β
(
t
)
x
t
f(x_{t},t)=-\frac{1}{2}\beta(t)x_{t}
f(xt,t)=−21β(t)xt和
g
(
t
)
=
β
(
t
)
g(t)=\sqrt{\beta(t)}
g(t)=β(t)。我们知道加噪是按照一个序列
{
β
i
}
i
=
1
T
\{\beta_{i} \}_{i=1}^{T}
{βi}i=1T进行,令
{
β
ˉ
i
=
T
β
i
}
i
=
1
T
\{\bar \beta_{i}=T\beta_{i} \}_{i=1}^{T}
{βˉi=Tβi}i=1T,则当
T
T
T趋近于无穷时,
{
β
ˉ
i
}
i
=
1
T
→
β
(
t
)
,
t
∈
[
0
,
1
]
\{\bar\beta_{i} \}_{i=1}^{T} \to \beta(t),t \in[0,1]
{βˉi}i=1T→β(t),t∈[0,1],接近一个函数。且有
β
(
i
T
)
=
β
ˉ
i
\beta(\frac{i}{T})=\bar \beta_{i}
β(Ti)=βˉi,并令
Δ
t
=
1
T
\Delta t=\frac{1}{T}
Δt=T1,那么:
x
t
+
1
=
1
−
β
ˉ
t
+
1
T
x
t
+
β
ˉ
t
+
1
T
ϵ
x
t
+
Δ
t
=
1
−
β
(
t
+
Δ
t
)
Δ
t
x
t
+
β
(
t
+
Δ
t
)
Δ
t
ϵ
≈
(
1
−
1
2
β
(
t
+
Δ
t
)
Δ
t
)
x
t
+
β
(
t
+
Δ
t
)
Δ
t
ϵ
≈
x
t
−
1
2
β
(
t
)
Δ
t
x
t
+
β
(
t
)
Δ
t
ϵ
\begin{align} x_{t+1}&=\sqrt{1-\frac{\bar \beta_{t+1}}{T}}x_{t}+\sqrt{\frac{\bar \beta_{t+1}}{T}}\epsilon \\ x_{t+\Delta t} &= \sqrt{1-\beta(t+\Delta t)\Delta t}x_{t}+\sqrt{\beta(t+\Delta t)\Delta t}\epsilon \\ &\approx (1-\frac{1}{2}\beta(t+\Delta t)\Delta t)x_{t}+\sqrt{\beta(t+\Delta t)}\sqrt{\Delta t}\epsilon \\ &\approx x_{t}-\frac{1}{2}\beta(t)\Delta tx_{t}+\sqrt{\beta(t)}\sqrt{\Delta t}\epsilon \end{align}
xt+1xt+Δt=1−Tβˉt+1xt+Tβˉt+1ϵ=1−β(t+Δt)Δtxt+β(t+Δt)Δtϵ≈(1−21β(t+Δt)Δt)xt+β(t+Δt)Δtϵ≈xt−21β(t)Δtxt+β(t)Δtϵ
至此,我们通过SDE表示出了VE和VP。
2.5 联系
Score-based Model中的score:
s
θ
(
x
t
,
t
)
s_{\theta}(x_{t},t)
sθ(xt,t)和DDPM中的Denoiser:
ϵ
θ
(
x
t
,
t
)
\epsilon_{\theta}(x_{t},t)
ϵθ(xt,t)之间有什么联系?在DDPM中,
x
t
∼
N
(
α
t
ˉ
x
0
+
1
−
α
t
ˉ
I
)
x_{t}\sim N(\sqrt{\bar{\alpha_{t}}}x_{0}+1-\bar{\alpha_{t}}I)
xt∼N(αtˉx0+1−αtˉI),则
p
(
x
t
)
∝
e
x
p
{
−
∣
∣
x
t
−
α
ˉ
t
x
0
∣
∣
2
2
2
(
1
−
α
ˉ
t
)
}
p(x_{t}) \propto exp\{-\frac{||x_{t}-\sqrt{\bar \alpha_{t}}x_{0}||_{2}^{2}}{2(1-\bar \alpha_{t})}\}
p(xt)∝exp{−2(1−αˉt)∣∣xt−αˉtx0∣∣22},将
p
(
x
t
)
p(x_{t})
p(xt)带入下面公式并求导得到score:
s
c
o
r
e
=
∇
x
t
l
o
g
p
(
x
t
)
=
−
x
t
−
α
ˉ
t
x
0
1
−
α
ˉ
t
\begin{equation} score=\nabla_{x_{t}}logp(x_{t})=-\frac{x_{t}-\sqrt{\bar \alpha_{t}}x_{0}}{1-\bar \alpha_{t}} \end{equation}
score=∇xtlogp(xt)=−1−αˉtxt−αˉtx0。
在DDPM中,
x
t
=
α
t
ˉ
x
0
+
1
−
α
t
ˉ
ϵ
x_{t}= \sqrt{\bar{\alpha_{t}}}x_{0}+\sqrt{1-\bar{\alpha_{t}}}\epsilon
xt=αtˉx0+1−αtˉϵ,可以推出
ϵ
=
x
t
−
α
t
ˉ
x
0
1
−
α
t
ˉ
\epsilon=\frac{x_{t}-\sqrt{\bar{\alpha_{t}}}x_{0}}{\sqrt{1-\bar{\alpha_{t}}}}
ϵ=1−αtˉxt−αtˉx0。且优化函数为
L
t
−
1
simple
=
E
x
0
,
ϵ
∼
N
(
0
,
I
)
[
∥
ϵ
−
ϵ
θ
(
α
ˉ
t
x
0
+
1
−
α
ˉ
t
ϵ
,
t
)
∥
2
]
\begin{equation} L_{t-1}^{\text {simple }}=\mathbb{E}_{\mathbf{x}_{0}, \epsilon \sim \mathcal{N}(0, \mathbf{I})}\left[\left\|\epsilon-\epsilon_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \epsilon, t\right)\right\|^{2}\right] \end{equation}
Lt−1simple =Ex0,ϵ∼N(0,I)[
ϵ−ϵθ(αˉtx0+1−αˉtϵ,t)
2]
那么可以得到
ϵ
θ
(
x
t
,
t
)
=
x
t
−
α
ˉ
t
x
0
1
−
α
ˉ
t
\epsilon_{\theta}(x_{t},t)=\frac{x_{t}-\sqrt{\bar \alpha_{t}}x_{0}}{\sqrt{1-\bar \alpha_{t}}}
ϵθ(xt,t)=1−αˉtxt−αˉtx0,并于公式22对比可以得到他们之间的联系:
s
θ
(
x
t
,
t
)
≈
−
1
1
−
α
ˉ
t
ϵ
θ
(
x
t
,
t
)
\begin{equation} s_{\theta}(x_{t},t) \approx -\frac{1}{\sqrt{1-\bar \alpha_{t}}}\epsilon_{\theta}(x_{t},t) \end{equation}
sθ(xt,t)≈−1−αˉt1ϵθ(xt,t)
即通过这个系数就可以完成Score-based Model和DDPM的转换。
VE与VP之间的联系?VE中: x t = x 0 + σ t ϵ x_{t}=x_{0}+\sigma_{t}\epsilon xt=x0+σtϵ,两边同除: x t 1 + σ t 2 = x 0 1 + σ t 2 + σ t 1 + σ t 2 ϵ \frac{x_{t}}{ \sqrt{1+\sigma_{t}^{2}}}=\frac{x_{0}}{\sqrt{1+\sigma_{t}^{2}}} +\frac{\sigma_{t}}{\sqrt{1+\sigma_{t}^{2}}}\epsilon 1+σt2xt=1+σt2x0+1+σt2σtϵ,我们令VP中的 x ˉ t = x t 1 + σ t 2 \bar x_{t}=\frac{x_{t}}{ \sqrt{1+\sigma_{t}^{2}}} xˉt=1+σt2xt,并令 α ˉ t = 1 1 + σ t 2 \sqrt{\bar \alpha_{t}}=\frac{1}{ \sqrt{1+\sigma_{t}^{2}}} αˉt=1+σt21,这样做的意义就构建他们之间的联系,就是通过VE中的 σ t \sigma_{t} σt来获得VP中的加噪序列 α t \alpha_{t} αt,最后 1 − α ˉ t = σ t 1 + σ t 2 \sqrt{1-\bar \alpha_{t}}=\frac{\sigma_{t}}{ \sqrt{1+\sigma_{t}^{2}}} 1−αˉt=1+σt2σt,则VP中: x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_{t}=\sqrt{\bar \alpha_{t}}x_{0}+\sqrt{1-\bar \alpha_{t}}\epsilon xt=αˉtx0+1−αˉtϵ。