[note] Training with Weighted Sum of Denoising Score Matching Objectives
利用 去噪分数匹配目标的加权和 进行训练,去噪指的是使用sde的方法就不需要自行补充噪声了。
本文的目的是解释如何对原始数据进行扰动。 from https://yang-song.github.io/blog/2021/score/
一、理论
首先,挑选一个随机过程(SDE)对原始数据分布 p 0 p_0 p0进行扰动得到扰动后数据的概率密度分布 p t p_t pt。
本文选择的随机过程为:
d
x
=
σ
t
d
w
,
t
∈
[
0
,
1
]
d{\bf x} = \sigma^td{\bf w}, \ t\in[0,1]
dx=σtdw, t∈[0,1]
在这种情况下,扰动后数据的概率密度分布
p
t
p_t
pt,在原始数据下的条件概率分布为:
p
0
t
(
x
(
t
)
∣
x
(
0
)
)
=
N
(
x
(
t
)
;
x
(
0
)
,
1
2
log
σ
(
σ
2
t
−
1
)
I
)
p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0)) = \mathcal{N}\bigg(\mathbf{x}(t); \mathbf{x}(0), \frac{1}{2\log \sigma}(\sigma^{2t} - 1) \mathbf{I}\bigg)
p0t(x(t)∣x(0))=N(x(t);x(0),2logσ1(σ2t−1)I)
关于这个函数的解释是,使用参数$ \frac{1}{2\log \sigma}(\sigma^{2t} - 1)
作
为
我
们
的
权
重
函
数
,
即
作为我们的权重函数,即
作为我们的权重函数,即\lambda(t) = \frac{1}{2 \log \sigma}(\sigma^{2t} - 1)$.
当参数
σ
\sigma
σ变得非常大的时候,其中的先验分布
p
t
=
1
p_{t=1}
pt=1,也就是最终扰动后的数据分布就可以变成一个正太分布:
∫
p
0
(
y
)
N
(
x
;
y
,
1
2
log
σ
(
σ
2
−
1
)
I
)
d
y
≈
N
(
x
;
0
,
1
2
log
σ
(
σ
2
−
1
)
I
)
,
\int p_0(\mathbf{y})\mathcal{N}\bigg(\mathbf{x}; \mathbf{y}, \frac{1}{2 \log \sigma}(\sigma^2 - 1)\mathbf{I}\bigg) d \mathbf{y} \approx \mathbf{N}\bigg(\mathbf{x}; \mathbf{0}, \frac{1}{2 \log \sigma}(\sigma^2 - 1)\mathbf{I}\bigg),
∫p0(y)N(x;y,2logσ1(σ2−1)I)dy≈N(x;0,2logσ1(σ2−1)I),
直观地说,这个SDE通过一个变种函数
1
2
l
o
g
σ
(
σ
2
t
−
1
)
\frac1{2\ log\ \sigma}(\sigma^{2t}-1)
2 log σ1(σ2t−1)帮助我们捕获了高斯扰动的数据变量集合(连续统continuum),即
x
(
t
)
x(t)
x(t)。这个数据变量集合可以帮助我们逐渐将原始数据分布
p
0
p_0
p0变成了一个简单的高斯分布
p
1
p_1
p1,也就是t=1时候的分布。
二、代码实现
1) 对t进行连续采样
# 对时间特征t进行均匀采样
random_t = torch.rand(x.x.shape[0]//30, device=device) * (1. - eps) + eps # 防止采样到0
2)定义权重函数
可以看到,这里定义的权重函数就是作者在上面提到的 λ ( t ) \lambda(t) λ(t)函数。
def marginal_prob_std(t, sigma):
# t = torch.tensor(t, device=device)
return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / np.log(sigma))
3)对数据进行扰动
# 表征时间的特征t, 从0到1上进行均匀采样
random_t = torch.rand(batchsize, device=device) * (1. - eps) + eps # 这里的eps是为了防止采样到t=0
# 构造一个与原始数据结构一样的向量,并在[0,1)上进行均匀采样。
z = torch.randn_like(x.x)
# 利用前面均匀采样的时间特征t,求得权重函数的值,这个权重函数的目的就是为了使得t=1时的扰动数据达到一个正太分布的结果。重复30遍的目的是因为一轮训练中设置的batch_size = 30
std = marginal_prob_std_func(random_t).repeat(1, 30).view(-1, 1)
# 这里将噪声与标准差相乘,
perturbed_x = copy.deepcopy(x)
perturbed_x.x += z * std
4)利用扰动的数据进行训练
需要补充一下,为了训练积分函数模型,目前的目标函数变成了下面这个样子:
E
t
∈
u
(
0
,
T
)
E
p
t
(
x
)
[
λ
(
t
)
∣
∣
∇
x
l
o
g
p
t
(
x
)
−
s
θ
(
x
,
t
)
∣
∣
2
2
]
\mathbb{E}_{t\in u(0,T)}\mathbb{E}_{p_t(x)}[\lambda(t)||\nabla_xlog\ p_t(x)-s_\theta(x,t)||_2^2]
Et∈u(0,T)Ept(x)[λ(t)∣∣∇xlog pt(x)−sθ(x,t)∣∣22]
这里是最基本的目标函数的样子:
E
p
(
x
)
[
∣
∣
∇
x
l
o
g
p
(
x
)
−
s
θ
(
x
)
∣
∣
2
2
]
=
∫
p
(
x
)
∣
∣
∇
x
l
o
g
p
(
x
)
−
s
θ
(
x
)
∣
∣
2
2
d
x
.
\mathbb{E}_{p(x)}[{||\nabla_xlog\ p(x)\ -\ s_\theta(x)||}_2^2]\ =\ \int\ p(x)||\nabla_x\ log\ p(x)\ -\ s_\theta(x)||_2^2dx.
Ep(x)[∣∣∇xlog p(x) − sθ(x)∣∣22] = ∫ p(x)∣∣∇x log p(x) − sθ(x)∣∣22dx.
为了估计这个目标函数,需要如下估计,即使用Score Matching的方法进行估计(Hyvärinen 2005):
可以看到,去估计如下的目标函数是可以达到的。
E p d a t a ( x ) [ 1 2 ∣ ∣ s θ ( x ) ∣ ∣ 2 2 + t r a c e ( ∇ x s θ ( x ) ) ] \mathbb{E}_{p_{data}(x)}[\frac12||s_\theta(x)||_2^2+trace(\nabla_xs_\theta(x))] Epdata(x)[21∣∣sθ(x)∣∣22+trace(∇xsθ(x))]
具体上,体现在代码上,用的是如下的公式:
1
N
∑
i
=
1
N
[
1
2
∣
∣
s
θ
(
x
i
)
∣
∣
2
2
+
t
r
a
c
e
(
∇
x
s
θ
(
x
i
)
)
]
≈
1
N
∑
i
=
1
N
[
1
2
∣
∣
s
θ
(
x
i
)
∣
∣
2
2
+
t
r
a
c
e
(
∇
x
s
θ
(
x
i
)
)
\frac1N\sum^N_{i=1}[\frac12||s_\theta(x_i)||_2^2+trace(\nabla_xs_\theta(x_i))] \\ \approx \frac1N \sum_{i=1}^N [\frac12||s_\theta(x_i)||_2^2+trace(\nabla_xs_\theta(x_i))
N1i=1∑N[21∣∣sθ(xi)∣∣22+trace(∇xsθ(xi))]≈N1i=1∑N[21∣∣sθ(xi)∣∣22+trace(∇xsθ(xi))
# 计算积分函数的值
output = model(perturbed_x, random_t)
# score matching的损失函数,与上式不一致的原因在于,本文的目标函数中还有一个参数\lambda(t),所以表现为如下的形式。
loss_ = torch.mean(torch.sum(((output * std + z)**2).view(batch_size, -1)), dim=-1)
# 一轮训练之后,将score matching的目标函数的结果返回
return loss_
🙋♂️ 我有一个问题,这个目标函数是怎么推理得到的呀? 🤔