目录
SDE
Stochastic Differential Equation,在多个噪声尺度上扰动数据是SMLD和DDPM成功的关键,SDE将其泛化到连续时间域,也即无限噪声尺度。
前向过程
用于数据扰动,训练得分模型。
d
x
=
f
(
x
,
t
)
d
t
+
g
(
t
)
d
t
dx = f(x, t)dt + g(t)dt
dx=f(x,t)dt+g(t)dt
f是x(t)的drift系数,g是x(t)的diffusion系数。
逆向过程
初始化噪声,采样生成样本。
d
x
=
[
f
(
x
,
t
)
−
g
(
t
)
2
Δ
x
l
o
g
p
t
(
x
)
]
d
t
+
g
(
t
)
d
w
ˉ
dx = [f(x, t) - g(t)^2\Delta_xlogp_t(x)]dt + g(t)d\bar{w}
dx=[f(x,t)−g(t)2Δxlogpt(x)]dt+g(t)dwˉ
目标函数
θ ∗ = a r g m i n θ E t { λ ( t ) E x ( 0 ) E x ( t ) ∣ x ( 0 ) [ ∣ ∣ s θ ( x ( t ) , t ) − Δ x ( t ) l o g p σ ( t ) ( x ( t ) ∣ x ( 0 ) ) ∣ ∣ 2 2 ] } \theta^{*} = \mathop{argmin}\limits_{\theta}E_t\{\lambda(t)E_{x(0)}E_{x(t)|x(0)}[||s_\theta(x(t), t)-\Delta_{x(t)}log_{p_\sigma(t)}(x(t)|x(0))||_2^2]\} θ∗=θargminEt{λ(t)Ex(0)Ex(t)∣x(0)[∣∣sθ(x(t),t)−Δx(t)logpσ(t)(x(t)∣x(0))∣∣22]}
Δ
x
~
l
o
g
q
(
x
~
∣
x
)
=
Δ
x
~
l
o
g
1
2
π
σ
2
e
x
p
{
−
∣
∣
x
~
−
x
∣
∣
2
2
σ
2
}
=
Δ
x
~
(
l
o
g
1
2
π
σ
2
−
∣
∣
x
~
−
x
∣
∣
2
2
σ
2
)
=
−
x
~
−
x
σ
2
=
−
z
σ
\begin{aligned} \Delta_{\widetilde{x}}log q(\widetilde{x}|x) &= \Delta_{\widetilde{x}} log\frac{1}{\sqrt{2\pi \sigma^2}}exp\{-\frac{||\widetilde{x}-x||^2}{2\sigma^2}\} \\ &= \Delta_{\widetilde{x}} (log\frac{1}{\sqrt{2\pi \sigma^2}} - \frac{||\widetilde{x}-x||^2} {2\sigma^2})\\ &=-\frac{\widetilde{x} - x}{\sigma^2} \\ &=-\frac{z}{\sigma} \end{aligned}
Δx
logq(x
∣x)=Δx
log2πσ21exp{−2σ2∣∣x
−x∣∣2}=Δx
(log2πσ21−2σ2∣∣x
−x∣∣2)=−σ2x
−x=−σz
所以SDE的loss代码实现为:
if not likelihood_weighting:
losses = torch.square(score * std[:, None, None, None, None] + z) # 用\sigma^2加权loss
losses = torch.mean(losses.reshape(losses.shape[0], -1), dim=-1)# [N,]
else:
g2 = sde.sde(torch.zeros_like(batch, t)[1] ** 2
losses = torch.square(score + z / std[:, None, None, None])
losses = torch.mean(losses.reshape(losses.shape[0], -1), dim=-1) * g2
losses = torch.mean(losses)
if中用的SDE推荐的loss加权方式,else中参考了Maximum Likelihood Training of Score-Based Diffusion Models
这篇工作中的loss加权方式,默认使用的是else中的加权方式。
逆向过程求解
s θ s_{\theta} sθ训好后,我们用它构建逆向SDE,通过数值方法生成符合 p 0 p_0 p0数据分布的样本。
通用数值ODE求解器
数值求解器提供了SDE的近似轨迹。存在许多通用的SDE数值求解方法,比如Euler-Maruyama和随机Runge-Kutta方法。
在CIFAR-10数据集上,reverse diffusion samplers比SMLD和DDPM中用的ancestral sampling要略好一些。
Predictor-Corrector采样器
有点复杂,而且巨慢,有需要再看。
Probability Flow ODE
对于所有的扩散过程,存在一个对应的确定性过程,该过程的轨迹和SDE共享边际概率密度
{
p
t
(
x
)
}
t
=
0
T
\{p_t(x)\}_{t=0}^T
{pt(x)}t=0T。该确定性过程满足一个ODE:
d
x
=
[
f
(
x
,
t
)
−
1
2
g
(
t
)
2
Δ
x
l
o
g
p
t
(
x
)
]
d
t
dx = [f(x, t) - \frac{1}{2}g(t)^2\Delta_xlogp_t(x)]dt
dx=[f(x,t)−21g(t)2Δxlogpt(x)]dt
我们称该ODE为probability flow ODE。
似然计算
操纵隐式表征
高效采样
VE, VP and SUB-VP SDEs推导过程
SMLD和DDPM的噪声扰动分别是Variance Exploding(VE)和Variance Preserving(VP) SDEs的离散化形式。sub-VP SDEs做为VP SDEs的修正,常常可以使生成的样本质量和似然表现更好。
VE SDEs
假设总计有N个噪声尺度,SMLD的每个扰动核 p σ i ( x i ∣ x 0 ) p_{\sigma_i}(x_i|x_0) pσi(xi∣x0)可从下述马尔可夫链推导出来:
x
i
=
x
i
−
1
+
σ
i
2
−
σ
i
−
1
2
z
i
−
1
,
i
=
1
,
.
.
.
,
N
(20)
x_i = x_{i-1} + \sqrt{\sigma^2_i - \sigma^2_{i-1}}z_{i-1}, i = 1, ..., N \tag{20}
xi=xi−1+σi2−σi−12zi−1,i=1,...,N(20)
where,
z
i
−
1
∼
N
(
0
,
I
)
,
x
0
∼
p
d
a
t
a
,
σ
0
=
0
z_{i-1}\sim N(\bold{0}, \bold{I)}, x_0 \sim p_{data}, \sigma_0=0
zi−1∼N(0,I),x0∼pdata,σ0=0, 当
N
N
N ->
∞
\infty
∞时,马尔可夫链
{
x
i
}
i
=
1
N
\{x_i\}_{i=1}^N
{xi}i=1N变为连续随机过程
{
x
(
t
)
}
t
=
0
1
\{x(t)\}^1_{t=0}
{x(t)}t=01,
{
σ
i
}
i
=
1
N
\{\sigma_i\}_{i=1}^N
{σi}i=1N变为函数
σ
(
t
)
\sigma(t)
σ(t),
z
i
z_i
zi变为
z
(
t
)
z(t)
z(t), 这里我们使用连续时间变量
t
∈
[
0
,
1
]
t\in[0, 1]
t∈[0,1]做索引,而不是整数
i
∈
{
1
,
2
,
.
.
.
,
N
}
i\in\{1, 2, ..., N\}
i∈{1,2,...,N}。
x
(
i
N
)
=
x
i
x(\frac{i}{N})=x_i
x(Ni)=xi,
σ
(
i
N
)
=
σ
i
\sigma(\frac{i}{N})=\sigma_i
σ(Ni)=σi,
z
(
i
N
)
=
z
i
z(\frac{i}{N})=z_i
z(Ni)=zi for i = 1, 2, …, N。可用
Δ
t
=
1
N
\Delta t=\frac{1}{N}
Δt=N1和
t
∈
{
0
,
1
N
,
.
.
.
,
N
−
1
N
}
t \in\{0, \frac{1}{N}, ..., \frac{N-1}{N}\}
t∈{0,N1,...,NN−1}重写上式(20):
x
(
t
+
Δ
t
)
=
x
(
t
)
+
σ
2
(
t
+
Δ
t
)
−
σ
2
(
t
)
z
(
t
)
≈
x
(
t
)
+
d
[
σ
2
(
t
)
]
d
t
Δ
t
z
(
t
)
x(t+\Delta t) = x(t) + \sqrt{\sigma^2(t+\Delta t) - \sigma^2(t)} z(t) \approx x(t) + \sqrt{\frac{d[\sigma^2(t)]}{dt}\Delta t} z(t)
x(t+Δt)=x(t)+σ2(t+Δt)−σ2(t)z(t)≈x(t)+dtd[σ2(t)]Δtz(t)
这里的增量
Δ
t
z
(
t
)
∼
N
(
0
,
Δ
t
)
\sqrt{\Delta t}z(t)\sim N(0, \Delta t)
Δtz(t)∼N(0,Δt)所构成的随机过程,天然满足维纳过程,进而可得:
d
x
=
d
[
σ
2
(
t
)
]
d
t
d
w
(21)
dx = \sqrt{\frac{d[\sigma^2(t)]}{dt}} dw \tag{21}
dx=dtd[σ2(t)]dw(21)
即不存在drift量。
VP SDEs
DDPM中用到的扰动核
{
p
σ
i
(
x
i
∣
x
0
)
}
i
=
1
N
\{p_{\sigma_i}(x_i|x_0)\}_{i=1}^N
{pσi(xi∣x0)}i=1N,离散马尔可夫链为:
x
i
=
1
−
β
i
x
i
−
1
+
β
i
z
i
−
1
,
i
=
1
,
.
.
.
,
N
(22)
x_i = \sqrt{1-\beta_i}x_{i-1} + \sqrt{\beta_i}z_{i-1}, i = 1,..., N \tag{22}
xi=1−βixi−1+βizi−1,i=1,...,N(22)
where,
z
i
−
1
∼
N
(
0
,
I
)
z_{i-1}\sim N(\bold{0}, \bold{I)}
zi−1∼N(0,I). 为了获得N->
∞
\infty
∞时, 该马尔可夫链的极限,我们定义噪声尺度的辅助集合
{
β
ˉ
i
=
N
β
i
}
i
=
1
N
\{\bar{\beta}_i = N\beta_i\}_{i=1}^N
{βˉi=Nβi}i=1N,并重写式(22)如下:
x
i
=
1
−
β
ˉ
i
N
x
i
−
1
+
β
ˉ
i
N
z
i
−
1
,
i
=
1
,
.
.
.
,
N
(23)
x_i = \sqrt{1-\frac{\bar{\beta}_i}{N}}x_{i-1} + \sqrt{\frac{\bar{\beta}_i}{N}}z_{i-1}, i = 1,..., N \tag{23}
xi=1−Nβˉixi−1+Nβˉizi−1,i=1,...,N(23)
当N趋于无穷大时,
{
β
ˉ
i
=
N
β
i
}
i
=
1
N
\{\bar{\beta}_i = N\beta_i\}_{i=1}^N
{βˉi=Nβi}i=1N成为以
t
∈
[
0
,
1
]
t\in[0, 1]
t∈[0,1]为索引的函数
β
(
t
)
\beta(t)
β(t),
β
(
i
N
)
=
β
ˉ
i
\beta(\frac{i}{N})=\bar{\beta}_i
β(Ni)=βˉi,
x
(
i
N
)
=
x
i
x(\frac{i}{N})=x_i
x(Ni)=xi, ,
z
(
i
N
)
=
z
i
z(\frac{i}{N})=z_i
z(Ni)=zi,可用
Δ
t
=
1
N
\Delta t=\frac{1}{N}
Δt=N1和
t
∈
{
0
,
1
N
,
.
.
.
,
N
−
1
N
}
t \in\{0, \frac{1}{N}, ..., \frac{N-1}{N}\}
t∈{0,N1,...,NN−1}重写上式(23):
x
(
t
+
Δ
t
)
=
1
−
β
(
t
+
Δ
t
)
Δ
t
x
(
t
)
+
β
(
t
+
Δ
t
)
Δ
t
z
(
t
)
≈
x
(
t
)
−
1
2
β
(
t
+
Δ
t
)
Δ
t
x
(
t
)
+
β
(
t
+
Δ
t
)
Δ
t
z
(
t
)
≈
x
(
t
)
−
1
2
β
(
t
)
Δ
t
x
(
t
)
+
β
(
t
)
Δ
t
z
(
t
)
(24)
\begin{aligned} x(t+\Delta t) &= \sqrt{1-\beta(t+\Delta t)\Delta t} x(t) + \sqrt{\beta(t+\Delta t)\Delta t}z(t)\\ &\approx x(t) - \frac{1}{2}\beta(t+\Delta t)\Delta t x(t) + \sqrt{\beta(t+\Delta t)\Delta t}z(t)\\ &\approx x(t) - \frac{1}{2}\beta(t)\Delta tx(t) + \sqrt{\beta(t)\Delta t}z(t) \tag{24} \end{aligned}
x(t+Δt)=1−β(t+Δt)Δtx(t)+β(t+Δt)Δtz(t)≈x(t)−21β(t+Δt)Δtx(t)+β(t+Δt)Δtz(t)≈x(t)−21β(t)Δtx(t)+β(t)Δtz(t)(24)
上式用到了泰勒近似,当
Δ
t
\Delta t
Δt << 1时,上式中的近似相等成立。因此当
Δ
\Delta
Δ -> 0时,式(24)收敛到下述VP SDE:
d
x
=
−
1
2
β
(
t
)
x
d
t
+
β
(
t
)
d
w
(25)
dx = - \frac{1}{2}\beta(t)xdt + \sqrt{\beta(t)}dw \tag{25}
dx=−21β(t)xdt+β(t)dw(25)
当 t t t-> ∞ \infty ∞时,VE SDE是个方差爆炸的过程。相比之下,VP SDE过程的方差是有界的,此外,当 p ( x ( 0 ) ) p(x(0)) p(x(0))是单位方差时,对于所有 t ∈ [ 0 , ∞ ) t\in[0, \infty) t∈[0,∞),该过程是固定的单位方差。
根据数理基础可得:
d
∑
V
P
(
t
)
d
t
=
β
(
t
)
(
I
−
∑
V
P
(
t
)
)
\frac{d\sum\nolimits_{VP}(t)}{dt} = \beta(t)(I - \sum\nolimits_{VP}(t))
dtd∑VP(t)=β(t)(I−∑VP(t))
∑
V
P
(
t
)
\sum\nolimits_{VP}(t)
∑VP(t)是VP SDE
x
(
t
)
x(t)
x(t)的协方差,解这个ODE可得:
∑
V
P
(
t
)
=
I
+
e
∫
0
t
−
β
(
s
)
d
s
(
∑
V
P
(
0
)
−
I
)
\sum\nolimits_{VP}(t) = I + e^{\int_0^t-\beta(s)ds}(\sum\nolimits_{VP}(0) - I)
∑VP(t)=I+e∫0t−β(s)ds(∑VP(0)−I)
据此可见,给定
∑
V
P
(
0
)
\sum\nolimits_{VP}(0)
∑VP(0),
∑
V
P
(
t
)
\sum\nolimits_{VP}(t)
∑VP(t)总是有界的,此外,当
∑
V
P
(
0
)
=
I
\sum\nolimits_{VP}(0)=I
∑VP(0)=I时,
∑
V
P
(
t
)
\sum\nolimits_{VP}(t)
∑VP(t)恒等于
I
I
I。
受VP-SDE启发,提出了新的SDE叫sub-VP SDE:
d
x
=
−
1
2
β
(
t
)
x
d
t
+
β
(
t
)
(
1
−
e
−
2
∫
0
t
β
(
s
)
d
s
)
d
w
(25)
dx = - \frac{1}{2}\beta(t)xdt + \sqrt{\beta(t)(1-e^{-2\int_0^t\beta(s)ds})}dw \tag{25}
dx=−21β(t)xdt+β(t)(1−e−2∫0tβ(s)ds)dw(25)
VP和sub-VP SDE的期望
E
[
x
(
t
)
]
E[x(t)]
E[x(t)]是相同的。方差函数不同:
∑
s
u
b
−
V
P
(
t
)
=
I
+
e
−
2
∫
0
t
β
(
s
)
d
s
I
+
e
−
∫
0
t
β
(
s
)
d
s
(
∑
s
u
b
−
V
P
(
0
)
−
2
I
)
\sum\nolimits_{sub-VP}(t) = I + e^{-2\int_0^t\beta(s)ds}I+ e^{-\int_0^t\beta(s)ds}(\sum\nolimits_{sub-VP}(0) - 2I)
∑sub−VP(t)=I+e−2∫0tβ(s)dsI+e−∫0tβ(s)ds(∑sub−VP(0)−2I)
可以有如下发现:
- 当 ∑ s u b − V P ( 0 ) = ∑ V P ( 0 ) \sum\nolimits_{sub-VP}(0) = \sum\nolimits_{VP}(0) ∑sub−VP(0)=∑VP(0) 并共享 β ( s ) \beta(s) β(s)时,对于所有的 t ≥ 0 t\geq0 t≥0, ∑ s u b − V P ( t ) ≤ ∑ V P ( t ) \sum\nolimits_{sub-VP}(t) \leq \sum\nolimits_{VP}(t) ∑sub−VP(t)≤∑VP(t) 。
- 如果
l
i
m
t
−
>
∞
∫
0
t
β
(
s
)
d
s
=
∞
lim_{t->\infty}\int_0^t\beta(s)ds=\infty
limt−>∞∫0tβ(s)ds=∞,
l
i
m
t
−
>
∞
∑
s
u
b
−
V
P
(
t
)
=
l
i
m
t
−
>
∞
∑
V
P
(
t
)
=
I
lim_{t->\infty}\sum\nolimits_{sub-VP}(t) = lim_{t->\infty}\sum\nolimits_{VP}(t) = I
limt−>∞∑sub−VP(t)=limt−>∞∑VP(t)=I。
从1可知,我们为什么称之为sub-VP SDE, 因为其方差总是被VP SDE限定。
这三个SDE的扰动核 p σ t ( x ( t ) ∣ x ( 0 ) ) p_{\sigma_t}(x(t)|x(0)) pσt(x(t)∣x(0))如下:
其实这三个就是代码中的marginal_prob
p
t
(
x
)
p_t(x)
pt(x),VE SDE计算mean和std非常简单,不做进一步解释。这里讲下(sub-)VP SDE中的积分项,关于
β
(
t
)
\beta(t)
β(t),我们定义了其两端值(0, 1e-4)
, (1, 2e-2)
,并且需要满足递增性,这样的函数形式有无数种,代码中选择了最简单的线性函数形式,
β
(
t
)
=
(
β
1
−
β
0
)
∗
t
+
β
0
\beta(t)=(\beta_1-\beta_0)*t + \beta_0
β(t)=(β1−β0)∗t+β0,则其在t时刻的积分值为
1
2
t
2
(
β
1
−
β
0
)
+
β
0
t
\frac{1}{2}t^2(\beta_1-\beta_0) + \beta_0 t
21t2(β1−β0)+β0t,据此就能计算出(sub-)VP SDE的mean和std,用于训练时的数据扰动。