生成模型SDE

SDE

Stochastic Differential Equation,在多个噪声尺度上扰动数据是SMLD和DDPM成功的关键,SDE将其泛化到连续时间域,也即无限噪声尺度。

前向过程

用于数据扰动,训练得分模型。
d x = f ( x , t ) d t + g ( t ) d t dx = f(x, t)dt + g(t)dt dx=f(x,t)dt+g(t)dt

f是x(t)的drift系数,g是x(t)的diffusion系数。

逆向过程

初始化噪声,采样生成样本。
d x = [ f ( x , t ) − g ( t ) 2 Δ x l o g p t ( x ) ] d t + g ( t ) d w ˉ dx = [f(x, t) - g(t)^2\Delta_xlogp_t(x)]dt + g(t)d\bar{w} dx=[f(x,t)g(t)2Δxlogpt(x)]dt+g(t)dwˉ

目标函数

θ ∗ = a r g m i n θ E t { λ ( t ) E x ( 0 ) E x ( t ) ∣ x ( 0 ) [ ∣ ∣ s θ ( x ( t ) , t ) − Δ x ( t ) l o g p σ ( t ) ( x ( t ) ∣ x ( 0 ) ) ∣ ∣ 2 2 ] } \theta^{*} = \mathop{argmin}\limits_{\theta}E_t\{\lambda(t)E_{x(0)}E_{x(t)|x(0)}[||s_\theta(x(t), t)-\Delta_{x(t)}log_{p_\sigma(t)}(x(t)|x(0))||_2^2]\} θ=θargminEt{λ(t)Ex(0)Ex(t)x(0)[∣∣sθ(x(t),t)Δx(t)logpσ(t)(x(t)x(0))22]}

Δ x ~ l o g q ( x ~ ∣ x ) = Δ x ~ l o g 1 2 π σ 2 e x p { − ∣ ∣ x ~ − x ∣ ∣ 2 2 σ 2 } = Δ x ~ ( l o g 1 2 π σ 2 − ∣ ∣ x ~ − x ∣ ∣ 2 2 σ 2 ) = − x ~ − x σ 2 = − z σ \begin{aligned} \Delta_{\widetilde{x}}log q(\widetilde{x}|x) &= \Delta_{\widetilde{x}} log\frac{1}{\sqrt{2\pi \sigma^2}}exp\{-\frac{||\widetilde{x}-x||^2}{2\sigma^2}\} \\ &= \Delta_{\widetilde{x}} (log\frac{1}{\sqrt{2\pi \sigma^2}} - \frac{||\widetilde{x}-x||^2} {2\sigma^2})\\ &=-\frac{\widetilde{x} - x}{\sigma^2} \\ &=-\frac{z}{\sigma} \end{aligned} Δx logq(x x)=Δx log2πσ2 1exp{2σ2∣∣x x2}=Δx (log2πσ2 12σ2∣∣x x2)=σ2x x=σz
所以SDE的loss代码实现为:

if not likelihood_weighting:
	losses = torch.square(score * std[:, None, None, None, None] + z) # 用\sigma^2加权loss
	losses = torch.mean(losses.reshape(losses.shape[0], -1), dim=-1)# [N,]
else:
	g2 = sde.sde(torch.zeros_like(batch, t)[1] ** 2
	losses = torch.square(score + z / std[:, None, None, None])
	losses = torch.mean(losses.reshape(losses.shape[0], -1), dim=-1) * g2
losses = torch.mean(losses)	

if中用的SDE推荐的loss加权方式,else中参考了Maximum Likelihood Training of Score-Based Diffusion Models这篇工作中的loss加权方式,默认使用的是else中的加权方式。

逆向过程求解

s θ s_{\theta} sθ训好后,我们用它构建逆向SDE,通过数值方法生成符合 p 0 p_0 p0数据分布的样本。

通用数值ODE求解器

数值求解器提供了SDE的近似轨迹。存在许多通用的SDE数值求解方法,比如Euler-Maruyama和随机Runge-Kutta方法。
在CIFAR-10数据集上,reverse diffusion samplers比SMLD和DDPM中用的ancestral sampling要略好一些。

Predictor-Corrector采样器

有点复杂,而且巨慢,有需要再看。

Probability Flow ODE

对于所有的扩散过程,存在一个对应的确定性过程,该过程的轨迹和SDE共享边际概率密度 { p t ( x ) } t = 0 T \{p_t(x)\}_{t=0}^T {pt(x)}t=0T。该确定性过程满足一个ODE:
d x = [ f ( x , t ) − 1 2 g ( t ) 2 Δ x l o g p t ( x ) ] d t dx = [f(x, t) - \frac{1}{2}g(t)^2\Delta_xlogp_t(x)]dt dx=[f(x,t)21g(t)2Δxlogpt(x)]dt
我们称该ODE为probability flow ODE。

似然计算

操纵隐式表征

高效采样

VE, VP and SUB-VP SDEs推导过程

SMLD和DDPM的噪声扰动分别是Variance Exploding(VE)和Variance Preserving(VP) SDEs的离散化形式。sub-VP SDEs做为VP SDEs的修正,常常可以使生成的样本质量和似然表现更好。

VE SDEs

假设总计有N个噪声尺度,SMLD的每个扰动核 p σ i ( x i ∣ x 0 ) p_{\sigma_i}(x_i|x_0) pσi(xix0)可从下述马尔可夫链推导出来:

x i = x i − 1 + σ i 2 − σ i − 1 2 z i − 1 , i = 1 , . . . , N (20) x_i = x_{i-1} + \sqrt{\sigma^2_i - \sigma^2_{i-1}}z_{i-1}, i = 1, ..., N \tag{20} xi=xi1+σi2σi12 zi1,i=1,...,N(20)
where, z i − 1 ∼ N ( 0 , I ) , x 0 ∼ p d a t a , σ 0 = 0 z_{i-1}\sim N(\bold{0}, \bold{I)}, x_0 \sim p_{data}, \sigma_0=0 zi1N(0,I),x0pdata,σ0=0, 当 N N N -> ∞ \infty 时,马尔可夫链 { x i } i = 1 N \{x_i\}_{i=1}^N {xi}i=1N变为连续随机过程 { x ( t ) } t = 0 1 \{x(t)\}^1_{t=0} {x(t)}t=01, { σ i } i = 1 N \{\sigma_i\}_{i=1}^N {σi}i=1N变为函数 σ ( t ) \sigma(t) σ(t), z i z_i zi变为 z ( t ) z(t) z(t), 这里我们使用连续时间变量 t ∈ [ 0 , 1 ] t\in[0, 1] t[0,1]做索引,而不是整数 i ∈ { 1 , 2 , . . . , N } i\in\{1, 2, ..., N\} i{1,2,...,N} x ( i N ) = x i x(\frac{i}{N})=x_i x(Ni)=xi, σ ( i N ) = σ i \sigma(\frac{i}{N})=\sigma_i σ(Ni)=σi, z ( i N ) = z i z(\frac{i}{N})=z_i z(Ni)=zi for i = 1, 2, …, N。可用 Δ t = 1 N \Delta t=\frac{1}{N} Δt=N1 t ∈ { 0 , 1 N , . . . , N − 1 N } t \in\{0, \frac{1}{N}, ..., \frac{N-1}{N}\} t{0,N1,...,NN1}重写上式(20):
x ( t + Δ t ) = x ( t ) + σ 2 ( t + Δ t ) − σ 2 ( t ) z ( t ) ≈ x ( t ) + d [ σ 2 ( t ) ] d t Δ t z ( t ) x(t+\Delta t) = x(t) + \sqrt{\sigma^2(t+\Delta t) - \sigma^2(t)} z(t) \approx x(t) + \sqrt{\frac{d[\sigma^2(t)]}{dt}\Delta t} z(t) x(t+Δt)=x(t)+σ2(t+Δt)σ2(t) z(t)x(t)+dtd[σ2(t)]Δt z(t)
这里的增量 Δ t z ( t ) ∼ N ( 0 , Δ t ) \sqrt{\Delta t}z(t)\sim N(0, \Delta t) Δt z(t)N(0,Δt)所构成的随机过程,天然满足维纳过程,进而可得:
d x = d [ σ 2 ( t ) ] d t d w (21) dx = \sqrt{\frac{d[\sigma^2(t)]}{dt}} dw \tag{21} dx=dtd[σ2(t)] dw(21)
即不存在drift量。

VP SDEs

DDPM中用到的扰动核 { p σ i ( x i ∣ x 0 ) } i = 1 N \{p_{\sigma_i}(x_i|x_0)\}_{i=1}^N {pσi(xix0)}i=1N,离散马尔可夫链为:
x i = 1 − β i x i − 1 + β i z i − 1 , i = 1 , . . . , N (22) x_i = \sqrt{1-\beta_i}x_{i-1} + \sqrt{\beta_i}z_{i-1}, i = 1,..., N \tag{22} xi=1βi xi1+βi zi1,i=1,...,N(22)
where, z i − 1 ∼ N ( 0 , I ) z_{i-1}\sim N(\bold{0}, \bold{I)} zi1N(0,I). 为了获得N-> ∞ \infty 时, 该马尔可夫链的极限,我们定义噪声尺度的辅助集合 { β ˉ i = N β i } i = 1 N \{\bar{\beta}_i = N\beta_i\}_{i=1}^N {βˉi=Nβi}i=1N,并重写式(22)如下:
x i = 1 − β ˉ i N x i − 1 + β ˉ i N z i − 1 , i = 1 , . . . , N (23) x_i = \sqrt{1-\frac{\bar{\beta}_i}{N}}x_{i-1} + \sqrt{\frac{\bar{\beta}_i}{N}}z_{i-1}, i = 1,..., N \tag{23} xi=1Nβˉi xi1+Nβˉi zi1,i=1,...,N(23)
当N趋于无穷大时, { β ˉ i = N β i } i = 1 N \{\bar{\beta}_i = N\beta_i\}_{i=1}^N {βˉi=Nβi}i=1N成为以 t ∈ [ 0 , 1 ] t\in[0, 1] t[0,1]为索引的函数 β ( t ) \beta(t) β(t) β ( i N ) = β ˉ i \beta(\frac{i}{N})=\bar{\beta}_i β(Ni)=βˉi x ( i N ) = x i x(\frac{i}{N})=x_i x(Ni)=xi, , z ( i N ) = z i z(\frac{i}{N})=z_i z(Ni)=zi,可用 Δ t = 1 N \Delta t=\frac{1}{N} Δt=N1 t ∈ { 0 , 1 N , . . . , N − 1 N } t \in\{0, \frac{1}{N}, ..., \frac{N-1}{N}\} t{0,N1,...,NN1}重写上式(23):
x ( t + Δ t ) = 1 − β ( t + Δ t ) Δ t x ( t ) + β ( t + Δ t ) Δ t z ( t ) ≈ x ( t ) − 1 2 β ( t + Δ t ) Δ t x ( t ) + β ( t + Δ t ) Δ t z ( t ) ≈ x ( t ) − 1 2 β ( t ) Δ t x ( t ) + β ( t ) Δ t z ( t ) (24) \begin{aligned} x(t+\Delta t) &= \sqrt{1-\beta(t+\Delta t)\Delta t} x(t) + \sqrt{\beta(t+\Delta t)\Delta t}z(t)\\ &\approx x(t) - \frac{1}{2}\beta(t+\Delta t)\Delta t x(t) + \sqrt{\beta(t+\Delta t)\Delta t}z(t)\\ &\approx x(t) - \frac{1}{2}\beta(t)\Delta tx(t) + \sqrt{\beta(t)\Delta t}z(t) \tag{24} \end{aligned} x(t+Δt)=1β(t+Δt)Δt x(t)+β(t+Δt)Δt z(t)x(t)21β(t+Δt)Δtx(t)+β(t+Δt)Δt z(t)x(t)21β(t)Δtx(t)+β(t)Δt z(t)(24)

上式用到了泰勒近似,当 Δ t \Delta t Δt << 1时,上式中的近似相等成立。因此当 Δ \Delta Δ -> 0时,式(24)收敛到下述VP SDE:
d x = − 1 2 β ( t ) x d t + β ( t ) d w (25) dx = - \frac{1}{2}\beta(t)xdt + \sqrt{\beta(t)}dw \tag{25} dx=21β(t)xdt+β(t) dw(25)

t t t-> ∞ \infty 时,VE SDE是个方差爆炸的过程。相比之下,VP SDE过程的方差是有界的,此外,当 p ( x ( 0 ) ) p(x(0)) p(x(0))是单位方差时,对于所有 t ∈ [ 0 , ∞ ) t\in[0, \infty) t[0,),该过程是固定的单位方差。

根据数理基础可得:
d ∑ V P ( t ) d t = β ( t ) ( I − ∑ V P ( t ) ) \frac{d\sum\nolimits_{VP}(t)}{dt} = \beta(t)(I - \sum\nolimits_{VP}(t)) dtdVP(t)=β(t)(IVP(t))
∑ V P ( t ) \sum\nolimits_{VP}(t) VP(t)是VP SDE x ( t ) x(t) x(t)的协方差,解这个ODE可得:
∑ V P ( t ) = I + e ∫ 0 t − β ( s ) d s ( ∑ V P ( 0 ) − I ) \sum\nolimits_{VP}(t) = I + e^{\int_0^t-\beta(s)ds}(\sum\nolimits_{VP}(0) - I) VP(t)=I+e0tβ(s)ds(VP(0)I)
据此可见,给定 ∑ V P ( 0 ) \sum\nolimits_{VP}(0) VP(0) ∑ V P ( t ) \sum\nolimits_{VP}(t) VP(t)总是有界的,此外,当 ∑ V P ( 0 ) = I \sum\nolimits_{VP}(0)=I VP(0)=I时, ∑ V P ( t ) \sum\nolimits_{VP}(t) VP(t)恒等于 I I I

受VP-SDE启发,提出了新的SDE叫sub-VP SDE:
d x = − 1 2 β ( t ) x d t + β ( t ) ( 1 − e − 2 ∫ 0 t β ( s ) d s ) d w (25) dx = - \frac{1}{2}\beta(t)xdt + \sqrt{\beta(t)(1-e^{-2\int_0^t\beta(s)ds})}dw \tag{25} dx=21β(t)xdt+β(t)(1e20tβ(s)ds) dw(25)
VP和sub-VP SDE的期望 E [ x ( t ) ] E[x(t)] E[x(t)]是相同的。方差函数不同:
∑ s u b − V P ( t ) = I + e − 2 ∫ 0 t β ( s ) d s I + e − ∫ 0 t β ( s ) d s ( ∑ s u b − V P ( 0 ) − 2 I ) \sum\nolimits_{sub-VP}(t) = I + e^{-2\int_0^t\beta(s)ds}I+ e^{-\int_0^t\beta(s)ds}(\sum\nolimits_{sub-VP}(0) - 2I) subVP(t)=I+e20tβ(s)dsI+e0tβ(s)ds(subVP(0)2I)
可以有如下发现:

  1. ∑ s u b − V P ( 0 ) = ∑ V P ( 0 ) \sum\nolimits_{sub-VP}(0) = \sum\nolimits_{VP}(0) subVP(0)=VP(0) 并共享 β ( s ) \beta(s) β(s)时,对于所有的 t ≥ 0 t\geq0 t0 ∑ s u b − V P ( t ) ≤ ∑ V P ( t ) \sum\nolimits_{sub-VP}(t) \leq \sum\nolimits_{VP}(t) subVP(t)VP(t)
  2. 如果 l i m t − > ∞ ∫ 0 t β ( s ) d s = ∞ lim_{t->\infty}\int_0^t\beta(s)ds=\infty limt>0tβ(s)ds= l i m t − > ∞ ∑ s u b − V P ( t ) = l i m t − > ∞ ∑ V P ( t ) = I lim_{t->\infty}\sum\nolimits_{sub-VP}(t) = lim_{t->\infty}\sum\nolimits_{VP}(t) = I limt>subVP(t)=limt>VP(t)=I
    从1可知,我们为什么称之为sub-VP SDE, 因为其方差总是被VP SDE限定。

这三个SDE的扰动核 p σ t ( x ( t ) ∣ x ( 0 ) ) p_{\sigma_t}(x(t)|x(0)) pσt(x(t)x(0))如下:

在这里插入图片描述
其实这三个就是代码中的marginal_prob p t ( x ) p_t(x) pt(x),VE SDE计算mean和std非常简单,不做进一步解释。这里讲下(sub-)VP SDE中的积分项,关于 β ( t ) \beta(t) β(t),我们定义了其两端值(0, 1e-4), (1, 2e-2),并且需要满足递增性,这样的函数形式有无数种,代码中选择了最简单的线性函数形式, β ( t ) = ( β 1 − β 0 ) ∗ t + β 0 \beta(t)=(\beta_1-\beta_0)*t + \beta_0 β(t)=(β1β0)t+β0,则其在t时刻的积分值为 1 2 t 2 ( β 1 − β 0 ) + β 0 t \frac{1}{2}t^2(\beta_1-\beta_0) + \beta_0 t 21t2(β1β0)+β0t,据此就能计算出(sub-)VP SDE的mean和std,用于训练时的数据扰动。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值