Score-matching 从另一个角度理解扩散模型
引入
我们用 p ( x ) p(x) p(x) 来表示真实的图片分布,定义 ∇ x l o g p ( x ) \nabla_xlogp(x) ∇xlogp(x) 即为score,假设我们可以得到每个 x x x处的score,那么我们就可以根据score的方向(梯度)去移动 x x x 为 x ^ \hat{x} x^ ,那么 l o g p ( x ^ ) > l o g p ( x ) logp(\hat{x})>logp(x) logp(x^)>logp(x),迭代去移动 x x x,最终会使得 l o g p ( x ) logp(x) logp(x) 越来越大,意味着x越来越接近真实的图片,我们要做的就是训练一个模型 s θ ( x ) s_{\theta}(x) sθ(x) 尽可能去拟合真实的 ∇ x l o g p ( x ) \nabla_xlogp(x) ∇xlogp(x)
两大问题
第一个问题是manifold hypothesis,这个假说认为现实世界的数据往往集中在嵌入高维空间中的低维流形上,由于梯度 ∇ x l o g p ( x ) \nabla_xlogp(x) ∇xlogp(x)是在高维空间中计算的,当 x x x被限制在低维流形上时,梯度无法定义。
第二个问题是Low Data Density Regions,在低数据密度区域中,数据稀缺会对使得梯度估计不准,另外也会对基于Langevin动力学的MCMC采样造成困难
另外还有一个更根本的问题就是,真实的 ∇ x l o g p ( x ) \nabla_xlogp(x) ∇xlogp(x) 我们根本不知道,既然无法得到这个ground truth,那也就根本没办法训练模型
解决方式
作者发现,通过给原始数据加高斯噪声 ,就可以完美的规避这两个问题。1)由于高斯噪声分布的分布是整个空间,扰动后的数据不会被限制在低维流形上,使得score估计变得明确。2)大量的高斯噪声可以填充原始未扰动数据分布中的低密度区域,因此score匹配可能获得更多的训练信号以改善估计。
作者采用了多级别的噪声干扰,在采样的时候先使用对应于大噪声的score,之后逐渐降低噪声级别
Noise Conditional Score Networks
作者定义了一系列的噪声强度
{
σ
i
}
i
=
1
L
\{ \sigma_i\}_{i=1}^L
{σi}i=1L,其满足
σ
1
σ
2
=
⋯
=
σ
L
−
1
σ
L
>
1
\frac{\sigma_1}{\sigma_2} = \cdots = \frac{\sigma_{L-1}}{\sigma_L} > 1
σ2σ1=⋯=σLσL−1>1,扰动后的数据分布为
q
σ
(
x
)
=
∫
p
data
(
t
)
log
q
σ
(
x
∣
t
)
d
t
w
h
e
r
e
log
q
σ
(
x
∣
t
)
=
N
(
x
∣
t
,
σ
2
I
)
q_{\sigma}(x) = \int p_{\text{data}}(t) \log q_{\sigma}(x | t) dt \\ where \ \ \log q_{\sigma}(x| t)=N(x|t,\sigma^2I)
qσ(x)=∫pdata(t)logqσ(x∣t)dtwhere logqσ(x∣t)=N(x∣t,σ2I)
我们让
σ
1
σ_1
σ1足够大以缓解刚刚讨论的两大问题,而
σ
L
σ_L
σL足够小以最小化对数据的影响,使得其与原始分布尽可能等价。我们关注公式(1)会发现,扰动后的数据分布是在原本的分布上叠加了一个高斯噪声,读到这里的时候很迷惑,因为本来的
∇
x
l
o
g
p
(
x
)
\nabla_xlogp(x)
∇xlogp(x)已经很复杂很难得到了,现在又给这个上边加了一个高斯噪声,那模型不久更没办法训练了吗
按照一步步推导,在给定的
σ
\sigma
σ 下我们的损失函数应该是这样的
ℓ
(
θ
;
σ
)
=
1
2
E
q
σ
(
x
~
)
[
∥
s
θ
(
x
~
,
σ
)
+
∇
x
~
log
q
σ
(
x
~
)
∥
2
2
]
\ell(\theta; \sigma) = \frac{1}{2} \mathbb{E}_{q_{\sigma}(\tilde{x})} \left[ \left\| s_{\theta}(\tilde{x}, \sigma) + \nabla_{\tilde{x}}\log q_{\sigma}(\tilde{x}) \right\|_2^2 \right]
ℓ(θ;σ)=21Eqσ(x~)[∥sθ(x~,σ)+∇x~logqσ(x~)∥22]
阅读论文后发现,作者的损失函数事实上是这样的
ℓ
(
θ
;
σ
)
=
1
2
E
p
data
(
x
)
E
x
~
∼
N
(
x
,
σ
2
I
)
[
∥
s
θ
(
x
~
,
σ
)
+
∇
x
~
log
q
σ
(
x
~
∣
x
)
∥
2
2
]
\ell(\theta; \sigma) = \frac{1}{2} \mathbb{E}_{p_{\text{data}}(x)} \mathbb{E}_{\tilde{x} \sim N(x, \sigma^2 I)} \left[ \left\| s_{\theta}(\tilde{x}, \sigma) + \nabla_{\tilde{x}}\log q_{\sigma}(\tilde{x}| x) \right\|_2^2 \right]
ℓ(θ;σ)=21Epdata(x)Ex~∼N(x,σ2I)[∥sθ(x~,σ)+∇x~logqσ(x~∣x)∥22]
作者认为这两项是等价的,这意味着对于条件分布的score估计的期望等价于对边际分布的score估计。作者并未对此进行证明,在文中一笔带过。
后来的另一篇文章flow matching里,在推导优化向量场的边际分布估计等价于优化每一个条件分布的估计时给出了类似证明,用那边文章中的方法可以证明(2)和(3)的等价性
在确定损失函数是(3)后,事情就变得简单多了,因为 log q σ ( x ∣ t ) = N ( x ∣ t , σ 2 I ) \log q_{\sigma}(x| t)=N(x|t,\sigma^2I) logqσ(x∣t)=N(x∣t,σ2I),故我们可以轻而易举的得到 ∇ x ~ log q σ ( x ~ ∣ x ) = − ( x ~ − x ) / σ 2 \nabla_{\tilde{x}}\log q_{\sigma}(\tilde{x}| x)=-(\tilde{x}-x)/\sigma^2 ∇x~logqσ(x~∣x)=−(x~−x)/σ2
所以在给定$\sigma $时,优化目标就变为
ℓ
(
θ
;
σ
)
=
1
2
E
p
data
(
x
)
E
x
~
∼
N
(
x
,
σ
2
I
)
[
∥
s
θ
(
x
~
,
σ
)
+
x
~
−
x
σ
2
∥
2
2
]
\ell(\theta; \sigma) = \frac{1}{2} \mathbb{E}_{p_{\text{data}}(x)} \mathbb{E}_{\tilde{x} \sim N(x, \sigma^2 I)} \left[ \left\| s_{\theta}(\tilde{x}, \sigma) + \frac{\tilde{x} - x}{\sigma^2} \right\|_2^2 \right]
ℓ(θ;σ)=21Epdata(x)Ex~∼N(x,σ2I)[
sθ(x~,σ)+σ2x~−x
22]
然后,我们将所有
σ
∈
{
σ
i
}
i
=
1
L
\sigma \in \{\sigma_i\}_{i=1}^L
σ∈{σi}i=1L的式(4)结合起来,得到一个统一的目标
L
(
θ
;
{
σ
i
}
i
=
1
L
)
=
1
L
∑
i
=
1
L
λ
(
σ
i
)
ℓ
(
θ
;
σ
i
)
L(\theta; \{\sigma_i\}_{i=1}^L) = \frac{1}{L} \sum_{i=1}^L \lambda(\sigma_i) \ell(\theta; \sigma_i)
L(θ;{σi}i=1L)=L1i=1∑Lλ(σi)ℓ(θ;σi)
为了让所有
{
σ
i
}
i
=
1
L
的
λ
(
σ
i
)
ℓ
(
θ
;
σ
i
)
\{\sigma_i\}_{i=1}^L的\lambda(\sigma_i)\ell(\theta; \sigma_i)
{σi}i=1L的λ(σi)ℓ(θ;σi)值大致处于同一数量级,作者取
λ
(
σ
)
=
σ
2
\lambda(\sigma) = \sigma^2
λ(σ)=σ2
NCSN inference via annealed Langevin dynamics
在采样阶段,作者采用了如下算法
这是在训练好NCSN s θ ( x , σ ) s_{\theta}(x, \sigma) sθ(x,σ) 后,受模拟退火和退火重要性采样的启发,提出的一种采样方法——退火Langevin动力学采样。该方法从某个固定的先验分布(如均匀噪声)初始化样本,然后通过Langevin动力学从 q σ 1 ( x ) q_{\sigma_1}(x) qσ1(x) 采样,步长为 α 1 \alpha_1 α1。接着,用减少的步长 α 2 \alpha_2 α2 从 q σ 2 ( x ) q_{\sigma_2}(x) qσ2(x) 采样,初始样本为上一次模拟的最终样本。以此类推,使用 q σ i − 1 ( x ) q_{\sigma_{i-1}}(x) qσi−1(x) 的最终样本作为 q σ i ( x ) q_{\sigma_i}(x) qσi(x) 的初始样本,逐步调低步长 α i \alpha_i αi,最终从 q σ L ( x ) q_{\sigma_L}(x) qσL(x) 采样,当 σ L ≈ 0 \sigma_L \approx 0 σL≈0 时, q σ L ( x ) q_{\sigma_L}(x) qσL(x) 接近 p data ( x ) p_{\text{data}}(x) pdata(x)。
由于分布 { q σ i } i = 1 L \{q_{\sigma_i}\}_{i=1}^L {qσi}i=1L 都被高斯噪声扰动,其支持集覆盖整个空间,分数定义良好,避免了流形假设的困难。当 σ 1 \sigma_1 σ1 足够大时, q σ 1 ( x ) q_{\sigma_1}(x) qσ1(x) 的低密度区域变小,模式变得不那么孤立,从而提高了分数估计的准确性,加快了Langevin动力学的混合速度。Langevin动力学生成的 q σ 1 ( x ) q_{\sigma_1}(x) qσ1(x) 样本可能位于 q σ 2 ( x ) q_{\sigma_2}(x) qσ2(x) 的高密度区域,这使得 q σ 1 ( x ) q_{\sigma_1}(x) qσ1(x) 的样本成为 q σ 2 ( x ) q_{\sigma_2}(x) qσ2(x) 的良好初始样本。类似地, q σ i − 1 ( x ) q_{\sigma_{i-1}}(x) qσi−1(x) 为 q σ i ( x ) q_{\sigma_i}(x) qσi(x) 提供了良好的初始样本,最终从 q σ L ( x ) q_{\sigma_L}(x) qσL(x) 获得高质量样本。
与DDPM的联系
score-matching中的加噪,难免会让人想起DDPM,两者都对原始数据进行了加噪处理,并在扰动后的数据训练模型,只不过DDPM中拟合的是 ϵ \epsilon ϵ,而score-matching中拟合的是 ∇ x ~ log q σ ( x ~ ) \nabla_{\tilde{x}}\log q_{\sigma}(\tilde{x}) ∇x~logqσ(x~),那么这两者是否有关系呢,是否可以通过某种方式建立联系呢?
如果对DDPM不太熟悉的话,可以先看看我的另一篇文章,简单回顾一下DDPM
这个方式就是
T
w
e
e
d
i
e
F
o
r
m
u
l
a
Tweedie\ Formula
Tweedie Formula:给定观测值Y,我们可以通过观测值 Y加上一个修正项(与噪声方差和观测值的对数概率密度梯度相关)来估计真实信号,特别的,如果扰动方式遵从高斯分布
z
:
N
(
μ
z
,
σ
z
)
z:N(\mu_z,\sigma_z)
z:N(μz,σz),我们有
E
[
μ
z
∣
z
]
=
z
+
σ
z
∇
z
log
p
(
z
)
E[\mu_z|z]=z+\sigma_z \nabla_z\log p(z)
E[μz∣z]=z+σz∇zlogp(z)
在DDPM中,通过马尔可夫链后的分布为
q
(
x
t
∣
x
0
)
=
N
(
x
t
;
α
ˉ
t
x
0
,
(
1
−
α
ˉ
t
)
I
)
q(x_t|x_0) = \mathcal{N} \left( x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) I \right)
q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I)
根据 Tweedie’s 公式,我们有
E
[
μ
x
t
∣
x
t
]
=
x
t
+
(
1
−
α
ˉ
t
)
∇
x
t
log
p
(
x
t
)
E[\mu_{x_t} | x_t] = x_t + (1 - \bar{\alpha}_t) \nabla_{x_t} \log p(x_t)
E[μxt∣xt]=xt+(1−αˉt)∇xtlogp(xt)
因此,我们可以得到
x
0
=
x
t
+
(
1
−
α
ˉ
t
)
∇
log
p
(
x
t
)
α
ˉ
t
x_0 = \frac{x_t + (1 - \bar{\alpha}_t) \nabla \log p(x_t)}{\sqrt{\bar{\alpha}_t}}
x0=αˉtxt+(1−αˉt)∇logp(xt)
在DDPM中,我们有
x
0
=
x
t
−
1
−
α
ˉ
t
ϵ
0
α
ˉ
t
x_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \epsilon_0}{\sqrt{\bar{\alpha}_t}}
x0=αˉtxt−1−αˉtϵ0
我们根据一步步推导,最终得到:
∇
x
~
log
q
σ
(
x
~
)
=
∇
log
p
(
x
t
)
=
−
ϵ
0
1
−
α
ˉ
t
\nabla_{\tilde{x}}\log q_{\sigma}(\tilde{x})=\nabla \log p(x_t) = -\frac{\epsilon_0}{\sqrt{1 - \bar{\alpha}_t}}
∇x~logqσ(x~)=∇logp(xt)=−1−αˉtϵ0
也就是说score-matching和DDPM拟合的目标只差了一个常数项,两种方式遵从完全不同的推导理念,最终却殊途同归,得到了一致的拟合目标。另外,仔细观察score-matching中提出的采样方法,也不难发现,DDPM的采样在公式上其实是退火Langevin动力学采样在T设定为1时的特例。