Score-Based Generative Modeling Through Stochastic Differential Equations (Paper reading)

Score-Based Generative Modeling Through Stochastic Differential Equations

Yang Song, Stanford University, ICLR2021, Cited:723, Code: 无, Paper.

1. 前言

这篇文章是关于一种新的生成建模方法,它通过随机微分方程(SDE)将复杂的数据分布平滑地转换为已知的先验分布。通过逆向时间SDE,可以将先验分布转换回数据分布。这种方法依赖于扰动数据分布的时间相关梯度场(即得分),可以使用神经网络准确估计这些得分,并使用数值SDE求解器生成样本。该框架包含了以前在基于得分的生成建模和扩散概率建模中的方法,并允许新的采样过程和新的建模能力。

2. 方法

2.1 SDE推导

随机微分方程SDE是一种微分方程,其中一个或多个项是随机过程,其解本身也是一个随机过程。SDE用于对股价和利率波动等现象进行建模。推导过程如下,假设我们有一个确定性的微分方程:
d x = f ( x ) d t \begin{equation} dx = f(x)dt \end{equation} dx=f(x)dt
其中, x x x是一个函数, f f f是一个已知的函数。我们可以用欧拉法来离散化这个方程,得到:
x t + Δ t − x t = f ( x t ) Δ t \begin{equation} x_{t+\Delta t} - x_t = f(x_t)\Delta t \end{equation} xt+Δtxt=f(xt)Δt
其中, Δ t \Delta t Δt是一个小的时间间隔。如果我们想要在这个方程中加入一些随机性,比如说由于测量误差或者外部干扰等原因,我们可以在右边加上一个噪声项:
x t + Δ t − x t = f ( x t ) Δ t + g ( x t ) Δ t ϵ \begin{equation} x_{t+\Delta t} - x_t = f(x_t)\Delta t + g(x_t)\sqrt{\Delta t}\epsilon \end{equation} xt+Δtxt=f(xt)Δt+g(xt)Δt ϵ
其中, g g g是一个已知的函数, ϵ \epsilon ϵ是一个服从标准正态分布的随机变量。这样,我们就得到了一个离散形式的SDE。如果我们让 Δ t → 0 \Delta t \to 0 Δt0,那么我们就可以得到连续形式的SDE的一般形式:
d x = f ( x ) d t + g ( x ) d W t \begin{equation} dx = f(x)dt + g(x)dW_t \end{equation} dx=f(x)dt+g(x)dWt
本身扩散模型就属于一个随机过程,那么用SDE描述扩散过程便是一个自然的事。

2.2 基于SDE的扩散过程

d x = f ( x , t ) d t + g ( t ) d w \begin{equation} dx = f(x,t)dt + g(t)dw \end{equation} dx=f(x,t)dt+g(t)dw
其中, f ( x , t ) f(x,t) f(x,t)是漂移项,表示确定性的变化; g ( t ) d w g(t)dw g(t)dw是扩散项,表示随机性的变化。 d w dw dw表示维纳过程(也称为布朗运动)的增量。飘逸项是从前一个状态到下一个状态的变化量。扩散项中的 d w dw dw指的是加入的噪声, g ( t ) g(t) g(t)是加噪的强度。

2.3 基于SDE的扩散重建

一下推导来自B站台UP主VictorYuki,将公式5写为离散形式
x t + Δ t = x t + f ( x t , t ) Δ t + g ( t ) Δ t ϵ \begin{equation} x_{t+\Delta t} = x_t + f(x_t,t)\Delta t + g(t)\sqrt{\Delta t }\epsilon \end{equation} xt+Δt=xt+f(xt,t)Δt+g(t)Δt ϵ
这里 ϵ ∼ N ( 0 , 1 ) \epsilon \sim N(0,1) ϵN(0,1) x t + Δ t x_{t+\Delta t} xt+Δt在给定 x t x_{t} xt时的概率分布,当 ϵ \epsilon ϵ是一个标准正态随机变量,在这种情况下, x t + Δ t x_{t+\Delta t} xt+Δt服从一个正态分布,均值为 x t + f ( x t , t ) Δ t x_t + f(x_t,t)\Delta t xt+f(xt,t)Δt,方差为 g 2 ( t ) Δ t I g^{2}(t)\Delta tI g2(t)ΔtI
p ( x t + Δ t ∣ x t ) ∼ N ( f ( x t , t ) Δ t , g 2 ( t ) Δ t I ) \begin{equation} p(x_{t+\Delta t}|x_{t}) \sim N(f(x_t,t)\Delta t , g^{2}(t)\sqrt{\Delta t }I) \end{equation} p(xt+Δtxt)N(f(xt,t)Δt,g2(t)Δt I)
要从 t + Δ t t+\Delta t t+Δt得到 x t x_{t} xt,需要求解:
p ( x t ∣ x t + Δ t ) = p ( x t + Δ t ∣ x t ) p ( x t ) / p ( x t + Δ t ) = p ( x t + Δ t ∣ x t ) e x p { l o g p ( x t ) − l o g p ( x t + Δ t ) } \begin{align} p(x_{t}|x_{t+\Delta t})&=p(x_{t+\Delta t}|x_{t})p(x_{t})/p(x_{t+\Delta t})\\ &= p(x_{t+\Delta t}|x_{t})exp\{logp(x_{t})-logp(x_{t+\Delta t})\} \end{align} p(xtxt+Δt)=p(xt+Δtxt)p(xt)/p(xt+Δt)=p(xt+Δtxt)exp{logp(xt)logp(xt+Δt)}
先将 l o g p ( x t + Δ t ) logp(x_{t+\Delta t}) logp(xt+Δt)泰勒一阶展开得 l o g p ( x t ) + ( x t + Δ t − x t ) ∇ x t l o g p ( x t ) + Δ t ∂ ∂ t l o g p ( x t ) logp(x_{t})+(x_{t+\Delta t}-x_{t}) \nabla _{x_{t}}logp(x_{t})+\Delta t \frac{\partial }{\partial t}logp(x_{t}) logp(xt)+(xt+Δtxt)xtlogp(xt)+Δttlogp(xt),带入公式9并将其第一项展开为高斯分布形式得:
E q . ( 9 ) ∝ e x p { − ∥ x t + Δ t − x t − f ( x t , t ) Δ t ∥ 2 2 2 g 2 ( t ) Δ t − ( x t + Δ t − x t ) ∇ x t l o g p ( x t ) − Δ t ∂ ∂ t l o g p ( x t ) } = e x p { − ∥ ( x t + Δ t − x t ) − ( f ( x t , t ) − g 2 ( t ) ∇ x t l o g p ( x t ) ) Δ t ∥ 2 2 2 g 2 ( t ) Δ t − Δ t ∂ ∂ t l o g p ( x t ) − f 2 ( x t , t ) Δ t g 2 ( t ) + ( f ( x t , t ) − g 2 ( t ) ∇ x t l o g p ( x t ) ) 2 Δ t 2 g 2 ( t ) } {\small \begin{align} Eq. (9) &\propto exp\left \{ -\frac{\left \| x_{t+\Delta t}-x_{t}-f(x_{t},t)\Delta t \right \|_{2}^{2} }{2g^{2}(t)\Delta t} - (x_{t+\Delta t}-x_{t})\nabla_{x_{t}} logp(x_{t}) - \Delta t \frac{\partial }{\partial t}logp(x_{t}) \right \} \\ &=exp\left \{ -\frac{\left \| (x_{t+\Delta t}-x_{t}) -(f(x_{t},t)- g^{2}(t)\nabla_{x_{t}} logp(x_{t}))\Delta t \right \|_{2}^{2} }{2g^{2}(t)\Delta t} -\Delta t \frac{\partial }{\partial t}logp(x_{t}) - \frac{f^{2}(x_{t},t)\Delta t}{g^{2}(t)} + \frac{(f(x_{t},t)-g^{2}(t)\nabla_{x_{t}} logp(x_{t}))^{2}\Delta t}{2g^{2}(t)} \right \} \end{align}} Eq.(9)exp{2g2(t)Δtxt+Δtxtf(xt,t)Δt22(xt+Δtxt)xtlogp(xt)Δttlogp(xt)}=exp{2g2(t)Δt(xt+Δtxt)(f(xt,t)g2(t)xtlogp(xt))Δt22Δttlogp(xt)g2(t)f2(xt,t)Δt+2g2(t)(f(xt,t)g2(t)xtlogp(xt))2Δt}
因为我们是从 x t + Δ t → x t x_{t+\Delta t}\to x_{t} xt+Δtxt,因此我们得到的分布应为都是关于时间 t + Δ t t+\Delta t t+Δt的,把(11)中的 t t t都改写,并当 Δ t → 0 \Delta t \to 0 Δt0时,后面三项:
E q . ( 11 ) ≈ e x p { − ∥ ( x t + Δ t − x t ) − ( f ( x t + Δ t , t + Δ t ) − g 2 ( t + Δ t ) ∇ x t + Δ t l o g p ( x t + Δ t ) ) Δ t ∥ 2 2 2 g 2 ( t + Δ t ) Δ t } {\small \begin{equation} Eq. (11) \approx exp\left \{ -\frac{\left \| (x_{t+\Delta t}-x_{t}) -(f(x_{t+\Delta t},t+\Delta t)- g^{2}(t+\Delta t)\nabla_{x_{t+\Delta t}} logp(x_{t+\Delta t}))\Delta t \right \|_{2}^{2} }{2g^{2}(t+\Delta t)\Delta t}\right \} \end{equation}} Eq.(11)exp{2g2(t+Δt)Δt(xt+Δtxt)(f(xt+Δt,t+Δt)g2(t+Δt)xt+Δtlogp(xt+Δt))Δt22}
可以得到 p ( x t ∣ x t + Δ t ) p(x_{t}|x_{t+\Delta t}) p(xtxt+Δt)的均值为 ( x t + Δ t − x t ) − ( f ( x t + Δ t , t + Δ t ) − g 2 ( t + Δ t ) ∇ x t + Δ t l o g p ( x t + Δ t ) (x_{t+\Delta t}-x_{t}) -(f(x_{t+\Delta t},t+\Delta t)- g^{2}(t+\Delta t)\nabla_{x_{t+\Delta t}} logp(x_{t+\Delta t}) (xt+Δtxt)(f(xt+Δt,t+Δt)g2(t+Δt)xt+Δtlogp(xt+Δt),方差为 2 g 2 ( t + Δ t ) Δ t 2g^{2}(t+\Delta t)\Delta t 2g2(t+Δt)Δt。则采样的连续公式为:
d x = [ f ( x , t ) − g 2 ( t ) ∇ x t l o g p ( x t ) ] d t + g ( t ) d w \begin{equation} dx=[f(x,t)-g^{2}(t)\nabla_{x_{t}}logp(x_{t})]dt+g(t)dw \end{equation} dx=[f(x,t)g2(t)xtlogp(xt)]dt+g(t)dw
离散的过程表示为:
x t − 1 = x t − [ f ( x t , t ) − g 2 ( t ) ∇ x t l o g p ( x t ) + g ( t ) ϵ \begin{equation} x_{t-1}=x_{t}-[f(x_{t},t)-g^{2}(t)\nabla_{x_{t}}logp(x_{t})+g(t) \epsilon \end{equation} xt1=xt[f(xt,t)g2(t)xtlogp(xt)+g(t)ϵ

2.4 Variance Exploring(VE)and Variation Preserving(VP)

SDE在理论上统一了Score-based Model (NCSN)和DDPM,他们分别对应VE and VP。VE扩散过程在扩散过程的每个步骤增加了更多的噪声,导致潜在变量的分布更广。这有助于探索不同的数据分布模式。VP扩散过程在扩散过程的每个步骤中保持方差恒定,从而导致潜在变量的球形高斯分布。这有助于保存信息并避免模式崩溃。这两种方法都旨在提高扩散模型采样的效率和准确性。然而,VE比VP具有一些优势,例如更好的似然估计、更快的收敛和更低的内存消耗。

方法Variance ExploringVariation Preserving
x t = x_{t}= xt= x 0 + σ t ϵ x_{0}+\sigma_{t}\epsilon x0+σtϵ α t ˉ x 0 + 1 − α t ˉ ϵ \sqrt{\bar{\alpha_{t}}}x_{0}+\sqrt{1-\bar{\alpha_{t}}}\epsilon αtˉ x0+1αtˉ ϵ
x t + 1 = x_{t+1}= xt+1= x t + σ t + 1 2 − σ t 2 ϵ x_{t}+\sqrt{\sigma_{t+1}^{2}-\sigma_{t}^{2}}\epsilon xt+σt+12σt2 ϵ 1 − β t + 1 x t + β t + 1 ϵ \sqrt{1-\beta_{t+1}}x_{t}+\sqrt{\beta_{t+1}}\epsilon 1βt+1 xt+βt+1 ϵ

我们现在要求公式6中的 f ( x t , t ) f(x_{t},t) f(xt,t) g ( t ) g(t) g(t),将公式17和公式21分别与公式6比对,就可求出:
VE f ( x t , t ) = 0 f(x_{t},t)=0 f(xt,t)=0 g ( t ) = d d t σ t 2 g(t)=\frac{d}{dt}\sigma_{t}^{2} g(t)=dtdσt2
x t + Δ t = x t + σ t + Δ t 2 − σ t 2 ϵ = x t + ( σ t + Δ t 2 − σ t 2 ) / Δ t Δ t ϵ = x t + Δ σ t 2 / Δ t Δ t ϵ \begin{align} x_{t+\Delta t}&=x_{t}+\sqrt{\sigma_{t+\Delta t}^{2}-\sigma_{t}^{2}}\epsilon\\ &= x_{t} + \sqrt{(\sigma_{t+\Delta t}^{2}-\sigma_{t}^{2})/\sqrt{\Delta t}} \sqrt{\Delta t}\epsilon\\ &=x_{t} + \sqrt{\Delta \sigma_{t}^{2}/\sqrt{\Delta t}} \sqrt{\Delta t}\epsilon \end{align} xt+Δt=xt+σt+Δt2σt2 ϵ=xt+(σt+Δt2σt2)/Δt Δt ϵ=xt+Δσt2/Δt Δt ϵ
VP f ( x t , t ) = − 1 2 β ( t ) x t f(x_{t},t)=-\frac{1}{2}\beta(t)x_{t} f(xt,t)=21β(t)xt g ( t ) = β ( t ) g(t)=\sqrt{\beta(t)} g(t)=β(t) 。我们知道加噪是按照一个序列 { β i } i = 1 T \{\beta_{i} \}_{i=1}^{T} {βi}i=1T进行,令 { β ˉ i = T β i } i = 1 T \{\bar \beta_{i}=T\beta_{i} \}_{i=1}^{T} {βˉi=Tβi}i=1T,则当 T T T趋近于无穷时, { β ˉ i } i = 1 T → β ( t ) , t ∈ [ 0 , 1 ] \{\bar\beta_{i} \}_{i=1}^{T} \to \beta(t),t \in[0,1] {βˉi}i=1Tβ(t),t[0,1],接近一个函数。且有 β ( i T ) = β ˉ i \beta(\frac{i}{T})=\bar \beta_{i} β(Ti)=βˉi,并令 Δ t = 1 T \Delta t=\frac{1}{T} Δt=T1,那么:
x t + 1 = 1 − β ˉ t + 1 T x t + β ˉ t + 1 T ϵ x t + Δ t = 1 − β ( t + Δ t ) Δ t x t + β ( t + Δ t ) Δ t ϵ ≈ ( 1 − 1 2 β ( t + Δ t ) Δ t ) x t + β ( t + Δ t ) Δ t ϵ ≈ x t − 1 2 β ( t ) Δ t x t + β ( t ) Δ t ϵ \begin{align} x_{t+1}&=\sqrt{1-\frac{\bar \beta_{t+1}}{T}}x_{t}+\sqrt{\frac{\bar \beta_{t+1}}{T}}\epsilon \\ x_{t+\Delta t} &= \sqrt{1-\beta(t+\Delta t)\Delta t}x_{t}+\sqrt{\beta(t+\Delta t)\Delta t}\epsilon \\ &\approx (1-\frac{1}{2}\beta(t+\Delta t)\Delta t)x_{t}+\sqrt{\beta(t+\Delta t)}\sqrt{\Delta t}\epsilon \\ &\approx x_{t}-\frac{1}{2}\beta(t)\Delta tx_{t}+\sqrt{\beta(t)}\sqrt{\Delta t}\epsilon \end{align} xt+1xt+Δt=1Tβˉt+1 xt+Tβˉt+1 ϵ=1β(t+Δt)Δt xt+β(t+Δt)Δt ϵ(121β(t+Δt)Δt)xt+β(t+Δt) Δt ϵxt21β(t)Δtxt+β(t) Δt ϵ
至此,我们通过SDE表示出了VE和VP。

2.5 联系

Score-based Model中的score: s θ ( x t , t ) s_{\theta}(x_{t},t) sθ(xt,t)和DDPM中的Denoiser: ϵ θ ( x t , t ) \epsilon_{\theta}(x_{t},t) ϵθ(xt,t)之间有什么联系?在DDPM中, x t ∼ N ( α t ˉ x 0 + 1 − α t ˉ I ) x_{t}\sim N(\sqrt{\bar{\alpha_{t}}}x_{0}+1-\bar{\alpha_{t}}I) xtN(αtˉ x0+1αtˉI),则 p ( x t ) ∝ e x p { − ∣ ∣ x t − α ˉ t x 0 ∣ ∣ 2 2 2 ( 1 − α ˉ t ) } p(x_{t}) \propto exp\{-\frac{||x_{t}-\sqrt{\bar \alpha_{t}}x_{0}||_{2}^{2}}{2(1-\bar \alpha_{t})}\} p(xt)exp{2(1αˉt)∣∣xtαˉt x022},将 p ( x t ) p(x_{t}) p(xt)带入下面公式并求导得到score:
s c o r e = ∇ x t l o g p ( x t ) = − x t − α ˉ t x 0 1 − α ˉ t \begin{equation} score=\nabla_{x_{t}}logp(x_{t})=-\frac{x_{t}-\sqrt{\bar \alpha_{t}}x_{0}}{1-\bar \alpha_{t}} \end{equation} score=xtlogp(xt)=1αˉtxtαˉt x0
在DDPM中, x t = α t ˉ x 0 + 1 − α t ˉ ϵ x_{t}= \sqrt{\bar{\alpha_{t}}}x_{0}+\sqrt{1-\bar{\alpha_{t}}}\epsilon xt=αtˉ x0+1αtˉ ϵ,可以推出 ϵ = x t − α t ˉ x 0 1 − α t ˉ \epsilon=\frac{x_{t}-\sqrt{\bar{\alpha_{t}}}x_{0}}{\sqrt{1-\bar{\alpha_{t}}}} ϵ=1αtˉ xtαtˉ x0。且优化函数为
L t − 1 simple  = E x 0 , ϵ ∼ N ( 0 , I ) [ ∥ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∥ 2 ] \begin{equation} L_{t-1}^{\text {simple }}=\mathbb{E}_{\mathbf{x}_{0}, \epsilon \sim \mathcal{N}(0, \mathbf{I})}\left[\left\|\epsilon-\epsilon_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \epsilon, t\right)\right\|^{2}\right] \end{equation} Lt1simple =Ex0,ϵN(0,I)[ ϵϵθ(αˉt x0+1αˉt ϵ,t) 2]
那么可以得到 ϵ θ ( x t , t ) = x t − α ˉ t x 0 1 − α ˉ t \epsilon_{\theta}(x_{t},t)=\frac{x_{t}-\sqrt{\bar \alpha_{t}}x_{0}}{\sqrt{1-\bar \alpha_{t}}} ϵθ(xt,t)=1αˉt xtαˉt x0,并于公式22对比可以得到他们之间的联系:
s θ ( x t , t ) ≈ − 1 1 − α ˉ t ϵ θ ( x t , t ) \begin{equation} s_{\theta}(x_{t},t) \approx -\frac{1}{\sqrt{1-\bar \alpha_{t}}}\epsilon_{\theta}(x_{t},t) \end{equation} sθ(xt,t)1αˉt 1ϵθ(xt,t)
即通过这个系数就可以完成Score-based Model和DDPM的转换。

VE与VP之间的联系?VE中: x t = x 0 + σ t ϵ x_{t}=x_{0}+\sigma_{t}\epsilon xt=x0+σtϵ,两边同除: x t 1 + σ t 2 = x 0 1 + σ t 2 + σ t 1 + σ t 2 ϵ \frac{x_{t}}{ \sqrt{1+\sigma_{t}^{2}}}=\frac{x_{0}}{\sqrt{1+\sigma_{t}^{2}}} +\frac{\sigma_{t}}{\sqrt{1+\sigma_{t}^{2}}}\epsilon 1+σt2 xt=1+σt2 x0+1+σt2 σtϵ,我们令VP中的 x ˉ t = x t 1 + σ t 2 \bar x_{t}=\frac{x_{t}}{ \sqrt{1+\sigma_{t}^{2}}} xˉt=1+σt2 xt,并令 α ˉ t = 1 1 + σ t 2 \sqrt{\bar \alpha_{t}}=\frac{1}{ \sqrt{1+\sigma_{t}^{2}}} αˉt =1+σt2 1,这样做的意义就构建他们之间的联系,就是通过VE中的 σ t \sigma_{t} σt来获得VP中的加噪序列 α t \alpha_{t} αt,最后 1 − α ˉ t = σ t 1 + σ t 2 \sqrt{1-\bar \alpha_{t}}=\frac{\sigma_{t}}{ \sqrt{1+\sigma_{t}^{2}}} 1αˉt =1+σt2 σt,则VP中: x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_{t}=\sqrt{\bar \alpha_{t}}x_{0}+\sqrt{1-\bar \alpha_{t}}\epsilon xt=αˉt x0+1αˉt ϵ

  • 13
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值