DiffusionModel: Score-Based Diffusion Models (Theory + Code)

  • Score-Based Generative Modeling through Stochastic Differential Equations
  • Improved Techniques for Training Score-Based Generative Models
  • Generative modeling by estimating gradients of the data distribution
  • High-Resolution Image Synthesis with Latent Diffusion Models

Reference blog:
https://blog.csdn.net/weixin_47748259/article/details/136768651
Reference video: deep_thought


Score and Naive Score-Based Models

Score and score network


  • Score function: by modeling the score function rather than the density function, we avoid the intractable normalizing constant. For a probability density function $p(x)$, the score is the gradient of the log-density with respect to the input:
    $$\nabla_x \log p(x) = \frac{\partial \log p(x)}{\partial x}$$

  • Score network: models the score, usually denoted $s_{\theta}$; its input is $D$-dimensional and so is its output (a minimal sketch follows this list).

    • $\mathbf{s}_{\boldsymbol{\theta}}:\mathbb{R}^{D}\to\mathbb{R}^{D}$
    • The network is trained so that its output approximates the score of the true data distribution $p_{\text{data}}(x)$, i.e. it learns the score function by minimizing $\mathbb{E}\big[\|s_\theta(x)-\nabla_x \log p_{\text{data}}(x)\|_2^2\big]$
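
A minimal sketch of this $\mathbb{R}^D\to\mathbb{R}^D$ mapping as a plain MLP (purely illustrative; the model actually used later in this post is a time-conditional U-Net):

import torch
import torch.nn as nn

D = 2  # data dimensionality
score_net = nn.Sequential(          # s_theta: R^D -> R^D
    nn.Linear(D, 128), nn.SiLU(),
    nn.Linear(128, 128), nn.SiLU(),
    nn.Linear(128, D),              # output dimensionality equals the input's
)
x = torch.randn(16, D)
print(score_net(x).shape)           # torch.Size([16, 2]), same shape as x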

Langevin dynamics


$$\tilde{\mathbf{x}}_t=\tilde{\mathbf{x}}_{t-1}+\frac{\epsilon}{2}\nabla_{\mathbf{x}}\log p(\tilde{\mathbf{x}}_{t-1})+\sqrt{\epsilon}\,\mathbf{z}_t$$

  • $\epsilon>0$: a fixed step size
  • $\pi$: the prior distribution (can be arbitrary)
  • $\nabla_{\mathbf{x}}\log p(\tilde{\mathbf{x}}_{t-1})$: the score function; training a score network $s_{\theta}$ to approximate it lets us sample by iterating this update, which is the basic idea of score-based generative models
  • $\tilde{x}_0$: sampled from the prior distribution $\pi(x)$
  • $\mathbf{z}_t\sim\mathcal{N}(0,I)$
  • $\tilde{x}_T$: approaches $p(x)$; as $\epsilon\to 0$ and $T\to\infty$, $\tilde{x}_T$ can be regarded as a sample from $p(x)$ (a minimal sketch of the loop follows this list)
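
A minimal sketch of this sampling loop, assuming a hypothetical score_fn(x) that returns an estimate of $\nabla_x\log p(x)$:

import torch

def langevin_sample(score_fn, x0, eps=1e-2, T=1000):
    x = x0.clone()
    for _ in range(T):
        z = torch.randn_like(x)                       # z_t ~ N(0, I)
        x = x + eps / 2 * score_fn(x) + eps**0.5 * z  # the update above
    return x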

Score matching


$$\frac{1}{2}\mathbb{E}_{p_{\mathrm{data}}}\big[\|\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x})-\nabla_{\mathbf{x}}\log p_{\mathrm{data}}(\mathbf{x})\|_{2}^{2}\big]$$

  • The loss above is equivalent to $\mathbb{E}_{p_{\mathrm{data}}(\mathbf{x})}\big[\operatorname{tr}(\nabla_{\mathbf{x}}\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x}))+\frac{1}{2}\|\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x})\|_{2}^{2}\big]$

  • However, computing the Jacobian (and its trace) is expensive, as the sketch below illustrates
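
A minimal sketch of this equivalent objective (score_net is a hypothetical $\mathbb{R}^D\to\mathbb{R}^D$ network); the Jacobian trace costs one backward pass per input dimension, which is exactly why this form scales poorly:

import torch

def ism_loss(score_net, x):
    x = x.requires_grad_(True)
    s = score_net(x)                                  # shape (batch, D)
    trace = torch.zeros(x.shape[0])
    for i in range(x.shape[1]):                       # one backward pass per dimension: O(D)
        grad_i = torch.autograd.grad(s[:, i].sum(), x, create_graph=True)[0]
        trace = trace + grad_i[:, i]                  # accumulate ds_i/dx_i
    return (trace + 0.5 * (s**2).sum(dim=1)).mean()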

Denoising Score Matching (assumes the added noise is small)

Objective for estimating the score of the noise-perturbed distribution via score matching:
$$\frac{1}{2}\mathbb{E}_{q_{\sigma}(\tilde{\mathbf{x}}\mid\mathbf{x})p_{\mathrm{data}}(\mathbf{x})}\big[\|\mathbf{s}_{\boldsymbol{\theta}}(\tilde{\mathbf{x}})-\nabla_{\tilde{\mathbf{x}}}\log q_{\sigma}(\tilde{\mathbf{x}}\mid\mathbf{x})\|_{2}^{2}\big]$$

  • $q_\sigma(\tilde{\mathbf{x}}\mid\mathbf{x})$: the distribution of $x$ after noise is added

  • $q_{\sigma}(\tilde{\mathbf{x}})\triangleq\int q_{\sigma}(\tilde{\mathbf{x}}\mid\mathbf{x})\,p_{\mathrm{data}}(\mathbf{x})\,\mathrm{d}\mathbf{x}$

  • $q_{\sigma}(\mathbf{x})\approx p_{\mathrm{data}}(\mathbf{x})$ holds only when the added noise is small; in low-density regions the score estimate is therefore inaccurate, which makes Langevin sampling even less accurate

Noise Conditional Score Networks (NCSN)

  • Problem: in denoising score matching, $q_{\sigma}(\mathbf{x})\approx p_{\mathrm{data}}(\mathbf{x})$ holds only when the added noise is small, so score estimates in low-density regions are inaccurate and Langevin sampling suffers; yet enlarging the noise would corrupt the data distribution.
  • Solution: perturb the data with noise of several different magnitudes; use a single conditional score network to estimate the score at every noise level; first sample using the score of the heavily perturbed distribution, then gradually reduce the noise magnitude toward 0, and finally run Langevin sampling to obtain samples from the correct original distribution.

Definition of NCSN

Noise: define a geometric sequence $\{\sigma_i\}_{i=1}^L$, $\sigma_i>0$, with $\frac{\sigma_1}{\sigma_2}=\frac{\sigma_2}{\sigma_3}=\dots=\frac{\sigma_{L-1}}{\sigma_L}>1$. The initial noise level is then large enough to "fill" low-density regions, while the final noise level is small enough to approximate the original data distribution well without over-perturbing it: $\sigma_1$ is large and $\sigma_L$ is small. A minimal sketch of building such a schedule follows.
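
A minimal sketch of building such a schedule (the endpoint values here are illustrative, not the paper's):

import numpy as np

sigma_1, sigma_L, L = 50.0, 0.01, 10
sigmas = np.exp(np.linspace(np.log(sigma_1), np.log(sigma_L), L))
print(sigmas[:-1] / sigmas[1:])  # constant ratio > 1: a geometric sequence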

  • Perturbed data distribution: $q_\sigma(x)=\int p_{\text{data}}(t)\,\mathcal{N}(x\mid t,\sigma^2 I)\,\mathrm{d}t$
  • NCSN trains a conditional score network to estimate the scores of the perturbed distributions: $\forall\sigma\in\{\sigma_{i}\}_{i=1}^{L}:\ \mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x},\sigma)\approx\nabla_{\mathbf{x}}\log q_{\sigma}(\mathbf{x})$
  • $\mathbf{s}_{\boldsymbol{\theta}}(\mathbf{x},\sigma)$: the noise-conditional score network
  • NCSN differs from DSM (denoising score matching): DSM uses an unconditional score network and adds only a small amount of noise, whereas NCSN conditions on the noise level.


Learning NCSN via score matching

NCSN uses a noise-conditional U-Net, with the following modifications:

  • Dilated/atrous convolutions and noise-conditional instance normalization (conditional instance normalization++) are added.

  • Moreover, even for the same pixel, different scores must be estimated under different noise magnitudes, so the model also takes the noise magnitude as an input.

  • Since the generated image has the same size as the original, every pixel is produced by Langevin dynamics, so the model must estimate a score for every pixel. In tensor terms, the network output (the score) must have the same shape as the input image. Given these input/output characteristics, a U-Net is exactly suited to the task.

The per-noise-level NCSN training objective is therefore:
$$\ell(\boldsymbol{\theta};\sigma)\triangleq\frac{1}{2}\mathbb{E}_{p_{\mathrm{data}}(\mathbf{x})}\mathbb{E}_{\tilde{\mathbf{x}}\sim\mathcal{N}(\mathbf{x},\sigma^{2}I)}\bigg[\bigg\|\mathbf{s}_{\boldsymbol{\theta}}(\tilde{\mathbf{x}},\sigma)+\frac{\tilde{\mathbf{x}}-\mathbf{x}}{\sigma^{2}}\bigg\|_{2}^{2}\bigg]$$

  • Let the noise distribution be Gaussian: $q_{\sigma}(\tilde{\mathbf{x}}\mid\mathbf{x})=\mathcal{N}(\tilde{\mathbf{x}}\mid\mathbf{x},\sigma^{2}I)$ (assumption)
  • Then $\nabla_{\tilde{\mathbf{x}}}\log q_{\sigma}(\tilde{\mathbf{x}}\mid\mathbf{x})=-(\tilde{\mathbf{x}}-\mathbf{x})/\sigma^{2}$, which yields the objective above

$$\mathcal{L}(\boldsymbol{\theta};\{\sigma_{i}\}_{i=1}^{L})\triangleq\frac{1}{L}\sum_{i=1}^{L}\lambda(\sigma_{i})\,\ell(\boldsymbol{\theta};\sigma_{i})$$

  • This objective combines the per-level objective above with weighting coefficients (a minimal sketch follows this list).

  • $\lambda(\sigma_{i})=\sigma_{i}^2$ is the noise weight; with this choice, $\lambda(\sigma)\,\ell(\boldsymbol{\theta};\sigma)$ does not depend on $\sigma$ and is of order 1, so no single noise level dominates training.
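
A minimal sketch of this weighted objective, assuming a hypothetical score_net(x, sigma), image-shaped inputs (batch, C, H, W), and a 1-D tensor sigmas of noise levels:

import torch

def ncsn_loss(score_net, x, sigmas):
    idx = torch.randint(len(sigmas), (x.shape[0],))   # one noise level per sample
    sigma = sigmas[idx].view(-1, 1, 1, 1)
    noise = torch.randn_like(x) * sigma
    x_tilde = x + noise                               # x_tilde ~ N(x, sigma^2 I)
    score = score_net(x_tilde, sigma.view(-1))
    # lambda(sigma) = sigma^2 folded in: 1/2 sigma^2 ||s + noise/sigma^2||^2 = 1/2 ||sigma*s + noise/sigma||^2
    per_sample = 0.5 * ((sigma * score + noise / sigma)**2).sum(dim=(1, 2, 3))
    return per_sample.mean()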

NCSN inference via annealed Langevin dynamics


  • First, draw an initial sample $\tilde{\mathbf{x}}_0$ from a prior or other random distribution.
  • Use step size $\alpha_{i}=\varepsilon\,\frac{\sigma_{i}^{2}}{\sigma_{L}^{2}}$ and run Langevin dynamics from the largest noise level $\sigma_1$ down to the smallest $\sigma_L$.
  • At each noise level, run $T$ iterations of Langevin dynamics: $\tilde{\mathbf{x}}_{t}\leftarrow\tilde{\mathbf{x}}_{t-1}+\frac{\alpha_{i}}{2}\mathbf{s}_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}_{t-1},\sigma_{i})+\sqrt{\alpha_{i}}\,\mathbf{z}_{t}$
  • The last sample generated at noise level $q_{\sigma_{i-1}}(\mathbf{x})$ is used as the initial sample for the next noise level $q_{\sigma_{i}}(\mathbf{x})$;
  • Finally, once Langevin sampling has completed at every noise level, the result is the final generated sample (see the sketch below).
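
A minimal sketch of annealed Langevin sampling, assuming a hypothetical score_net(x, sigma) and a noise sequence sigmas with $\sigma_1>\dots>\sigma_L$:

import torch

def anneal_langevin(score_net, sigmas, shape, T=100, eps=2e-5):
    x = torch.rand(shape)                        # initial sample from a uniform prior
    for sigma in sigmas:                         # largest noise level first
        alpha = eps * sigma**2 / sigmas[-1]**2   # alpha_i = eps * sigma_i^2 / sigma_L^2
        for _ in range(T):
            z = torch.randn_like(x)
            x = x + alpha / 2 * score_net(x, sigma) + alpha**0.5 * z
    return x                                     # the sample at the smallest noise level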

Score-based generative modeling with stochastic differential equations (SDEs)

Previously we added noise a finite number of times; generalizing to infinitely many noise levels gives a more general perturbation process, which can be described by an SDE.
If $x(t)$ is a continuous quantity (noise is added infinitely many times), the diffusion process can be expressed by an SDE. Brownian motion has independent increments, Gaussian-distributed increments, and continuous trajectories; a stochastic differential equation is a differential equation containing random parameters, a random process, random initial values, or random boundary values, and here the randomness of $w$ is what makes this an SDE:
$$\mathrm{d}\mathbf{x}=\mathbf{f}(\mathbf{x},t)\,\mathrm{d}t+g(t)\,\mathrm{d}\mathbf{w}$$

  • $f$: drift coefficient
  • $g$: diffusion coefficient
  • $w$: a standard Brownian motion
  • $x_0 \sim p_0$, $x_T \sim p_T$: here $p_0$ and $p_T$ can be viewed as the boundary conditions of the SDE
  • The SDE can be viewed as the limit, as $\Delta t\to 0$, of $\boldsymbol{x}_{t+\Delta t}-\boldsymbol{x}_t=\boldsymbol{f}_t(\boldsymbol{x}_t)\Delta t+g_t\sqrt{\Delta t}\,\boldsymbol{\varepsilon}$, $\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0},\boldsymbol{I})$ (a simulation sketch follows this list)
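
A minimal sketch of simulating the forward SDE with exactly this discretization (f and g are hypothetical drift and diffusion functions):

import torch

def simulate_forward_sde(x0, f, g, T=1.0, N=1000):
    x, dt = x0.clone(), T / N
    for i in range(N):
        t = i * dt
        eps = torch.randn_like(x)
        x = x + f(x, t) * dt + g(t) * dt**0.5 * eps  # x_{t+dt} = x_t + f_t(x_t) dt + g_t sqrt(dt) eps
    return x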

Reversing the SDE for sample generation

$$\mathrm{d}\mathbf{x}=\big[\mathbf{f}(\mathbf{x},t)-g^{2}(t)\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})\big]\,\mathrm{d}t+g(t)\,\mathrm{d}\bar{\mathbf{w}}$$

  • This is the reverse-time SDE.

  • $\bar{w}$: a Brownian motion running in the reverse-time direction

  • $\mathrm{d}t$: an infinitesimal negative time step

  • $\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})$: the (time-dependent) score function

If $\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})$ is known, the reverse SDE can be solved to reconstruct samples. We therefore train a neural network $s_{\theta}(x,t)$ to estimate $\nabla_{\mathbf{x}}\log p_{t}(\mathbf{x})$ directly:

$$\min_{\theta}\ \mathbb{E}_{t\sim\mathcal{U}(0,T)}\Big[\lambda(t)\,\mathbb{E}_{\mathbf{x}(0)\sim p_{0}(\mathbf{x})}\mathbb{E}_{\mathbf{x}(t)\sim p_{0t}(\mathbf{x}(t)\mid\mathbf{x}(0))}\big[\|s_{\theta}(\mathbf{x}(t),t)-\nabla_{\mathbf{x}(t)}\log p_{0t}(\mathbf{x}(t)\mid\mathbf{x}(0))\|_{2}^{2}\big]\Big]$$

  • $t$ is uniform on $[0,T]$, which can be normalized to $[0,1]$
  • $\mathbf{x}(0)\sim p_{0}(\mathbf{x})$: i.i.d. samples from $p_0(x)$; each $x(0)$ is a training sample
  • $\mathbf{x}(t)\sim p_{0t}(\mathbf{x}(t)\mid\mathbf{x}(0))$: the conditional (perturbation-kernel) distribution
  • $\lambda(t)=\frac{1}{\mathbb{E}[\|\nabla_{\mathbf{x}}\log p_{0t}(\mathbf{x}(t)\mid\mathbf{x}(0))\|_{2}^{2}]}$: for a Gaussian kernel this corresponds to the variance; this choice keeps the loss magnitude consistent across noise levels

Tips on designing score-based models

  • The input and output have the same dimensionality
  • Time information is passed in as a condition via a Gaussian random-feature encoding:
    $$\omega\sim\mathcal{N}(\mathbf{0},s^{2}\mathbf{I})\ (\text{fixed}),\qquad [\sin(2\pi\omega t);\cos(2\pi\omega t)]\ (\text{the time encoding})$$
  • The trained U-Net output can be multiplied by $1/\sqrt{\mathbb{E}[\|\nabla_x\log p_{0t}(\mathbf{x}(t)\mid\mathbf{x}(0))\|_2^2]}$ so that its L2 norm approaches $\mathbb{E}[\|\nabla_\mathbf{x}\log p_{0t}(\mathbf{x}(t)\mid\mathbf{x}(0))\|_2]$
  • Use an EMA (exponential moving average) of the weights at sampling time

PyTorch Code Implementation

Define the U-Net-based Score Prediction Network

#@title Defining a time-dependent score-based model (double click to expand or collapse)

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools

class TimeEncoding(nn.Module):
    """Gaussian random Fourier feature encoding of time t."""
    
    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        self.w = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
    
    def forward(self, x):
        x_proj = x[:, None] * self.w[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

class Dense(nn.Module):
    """ A fully connected layer that reshapes outputs to feature maps. """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
        return self.dense(x)[..., None, None]  # add two trailing spatial dims so it broadcasts over feature maps
    
class ScoreNet(nn.Module):
    """ 基于U-net的时间依赖的分数估计模型 """
    
    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
        """ Initialize a time-dependent score-based network.
        
        Args:
            marginal_prob_std: A function that takes time t and gives the standard
                deviation of the perturbation kernel p_{0t}(x(t)|x(0)).
            channels: The number of channels for feature maps of each resolution.
            embed_dim: The dimensionality of Gaussian random feature embeddings.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time
        self.embed = nn.Sequential(TimeEncoding(embed_dim=embed_dim),
                                   nn.Linear(embed_dim, embed_dim))  # time-embedding layers
        
        # U-Net encoder: spatial resolution shrinks while the channel count grows
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)  # to skip connection
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)  # to skip connection
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)  # to skip connection
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)  # to skip connection
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
        
        # U-Net decoder: spatial resolution grows while the channel count shrinks, with skip connections from the encoder
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        
        # The swish activation function
        self.act = lambda x: x * torch.sigmoid(x)
        
        self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)
        self.marginal_prob_std = marginal_prob_std
        
    def forward(self, x, t):
        # Encode time t
        embed = self.act(self.embed(t))
        
        # Encoder forward pass
        h1 = self.conv1(x)
        h1 += self.dense1(embed)  # inject time embedding
        h1 = self.gnorm1(h1)
        h1 = self.act(h1)
        h2 = self.conv2(h1)
        h2 += self.dense2(embed)  # inject time embedding
        h2 = self.gnorm2(h2)
        h2 = self.act(h2)
        h3 = self.conv3(h2)
        h3 += self.dense3(embed)  # inject time embedding
        h3 = self.gnorm3(h3)
        h3 = self.act(h3)
        h4 = self.conv4(h3)
        h4 += self.dense4(embed)  # inject time embedding
        h4 = self.gnorm4(h4)
        h4 = self.act(h4)
        
        # Decoder forward pass
        h = self.tconv4(h4)
        h += self.dense5(embed)  # inject time embedding
        h = self.tgnorm4(h)
        h = self.act(h)
        h = self.tconv3(torch.cat([h, h3], dim=1))  # skip connection
        h += self.dense6(embed)  # inject time embedding
        h = self.tgnorm3(h)
        h = self.act(h)
        h = self.tconv2(torch.cat([h, h2], dim=1))  # skip connection
        h += self.dense7(embed)  # inject time embedding
        h = self.tgnorm2(h)
        h = self.act(h)
        h = self.tconv1(torch.cat([h, h1], dim=1))  # skip connection
        
        # Normalize output so the predicted score's L2 norm approaches that of the true score
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        
        return h
device = 'cuda:0'  # cuda or cpu

def marginal_prob_std(t, sigma):
    """Compute the standard deviation of the perturbation kernel p_{0t}(x(t)|x(0)) at any time t."""
    
    t = torch.tensor(t, device=device)
    return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
    """Compute the diffusion coefficient at any time t; the SDE defined in this example has no drift term."""
    
    return torch.tensor(sigma**t, device=device)

sigma = 25.0
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)  # bind sigma so only t remains
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)  # bind sigma so only t remains
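
As an aside, the expression in marginal_prob_std follows from the drift-free (variance-exploding) SDE used in this example, $\mathrm{d}x=\sigma^t\,\mathrm{d}\mathbf{w}$, matching diffusion_coeff above: the conditional variance accumulates as

$$\mathrm{Var}[x(t)\mid x(0)]=\int_0^t g(s)^2\,\mathrm{d}s=\int_0^t\sigma^{2s}\,\mathrm{d}s=\frac{\sigma^{2t}-1}{2\log\sigma},$$

so the standard deviation of the perturbation kernel is $\sqrt{(\sigma^{2t}-1)/(2\log\sigma)}$, exactly what the code returns.
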
def loss_fn(model, x, marginal_prob_std, eps=1e-5):
    """The loss function for training score-based generative models.
    
    Args:
        model: A PyTorch model instance that represents a 
            time-dependent score-based model.
        x: A mini-batch of training data.
        marginal_prob_std: A function that gives the standard deviation of 
            the perturbation kernel.
        eps: A tolerance value for numerical stability.
    """
    # Step 1: draw batch_size random time points t from [eps, 1)
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    
    # Step 2: sample perturbed_x from p_t(x) via the reparameterization trick
    z = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    perturbed_x = x + z * std[:, None, None, None]
    
    # Step 3: feed the perturbed sample and the time into the score network to predict the score
    score = model(perturbed_x, random_t)
    
    # Step 4: compute the denoising score matching loss
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3)))
    return loss
from copy import deepcopy

class EMA(nn.Module):
    def __init__(self, model, decay=0.9999, device=None):
        super(EMA, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if self.device is not None:
            self.module.to(device=device)
            
    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))
                
    def update(self, model):
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
        
    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)

Training the score-based model on MNIST data

#@title Training (double click to expand or collapse)

import torch
import functools
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import tqdm


score_model = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)

n_epochs = 50  #@param {'type': 'integer'}
## size of a mini-batch
batch_size = 32  #@param {'type': 'integer'}
## learning rate
lr = 1e-4  #@param {'type': 'number'}

dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

optimizer = Adam(score_model.parameters(), lr=lr)
tqdm_epoch = tqdm.tqdm(range(n_epochs))

ema = EMA(score_model)
for epoch in tqdm_epoch:  # training speed: Mac CPU ~250 s/epoch, Kaggle GPU ~30 s/epoch, Google Colab ~35 s/epoch, TPU untested
    avg_loss = 0.
    num_items = 0
    for x, y in data_loader:
        x = x.to(device)
        loss = loss_fn(score_model, x, marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(score_model)
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    
    print('Average ScoreMatching Loss: {:5f}'.format(avg_loss / num_items))
    # torch.save(score_model.state_dict(), f'ckpt_{epoch}.pth')  # optionally keep per-epoch checkpoints
    torch.save(score_model.state_dict(), 'ckpt.pth')
        
## The number of sampling steps.
num_steps = 500
def euler_sampler(score_model,
                  marginal_prob_std,
                  diffusion_coeff,
                  batch_size=64,
                  num_steps=num_steps,
                  device='cuda:0',
                  eps=1e-3):
    
    # Step 1: set the initial time t=1 and draw a random sample from the prior
    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
        * marginal_prob_std(t)[:, None, None, None]
    
    # Step 2: define the reverse-time grid and the per-step size
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    
    # Step 3: solve the reverse-time SDE with the Euler-Maruyama method
    x = init_x
    with torch.no_grad():
        for time_step in tqdm.tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_time_step)
            mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
            
    # Step 4: return the mean of the last step as the generated sample (no noise at the final step)
    return mean_x
signal_to_noise_ratio = 0.16
num_steps = 500

def pc_sampler(score_model,
               marginal_prob_std,
               diffusion_coeff,
               batch_size=64,
               num_steps=num_steps,
               snr=signal_to_noise_ratio,  # extra argument compared to euler_sampler
               device='cuda:0',
               eps=1e-3):
    """ Generate samples from score-based models with Predictor-Corrector method.
    
    Args:
        score_model: A PyTorch model that represents the time-dependent score-based model.
        marginal_prob_std: A function that gives the standard deviation
            of the perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient
            of the SDE.
        batch_size: The number of samplers to generate by calling this function once.
        num_steps: The number of sampling steps.
            Equivalent to the number of discretized time steps.
        device: 'cuda' for running on GPUs. and 'cpu' for running on CPUs.
        eps: The smallest time step for numerical stability.
        
    Returns:
        Samples.
    """
    # Step 1: set the initial time t=1 and draw a random sample from the prior
    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, 1, 28, 28, device=device) * marginal_prob_std(t)[:, None, None, None]
    
    # Step 2: define the reverse-time grid and the per-step size
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    
    # Step 3: alternate Langevin MCMC (corrector) with Euler-Maruyama steps of the reverse-time SDE (predictor)
    x = init_x
    with torch.no_grad():
        for time_step in tqdm.tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            
            # Corrector step (Langevin MCMC)
            grad = score_model(x, batch_time_step)
            grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
            noise_norm = np.sqrt(np.prod(x.shape[1:]))
            langevin_step_size = 2 * (snr * noise_norm / grad_norm) **2
            #print(f"langevin_step_size={langevin_step_size}")
            
            for _ in range(10):
                x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)  # Langevin update
                grad = score_model(x, batch_time_step)
                grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
                noise_norm = np.sqrt(np.prod(x.shape[1:]))
                langevin_step_size = 2 * (snr * noise_norm / grad_norm)**2
                #print(f"langevin_step_size={langevin_step_size}")
            
            # Predictor step (Euler-Maruyama)
            g = diffusion_coeff(batch_time_step)
            x_mean = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
            x = x_mean + torch.sqrt(g**2 * step_size)[:, None, None, None] * torch.randn_like(x)
        
    # Step 4: return the mean of the final Euler step as the generated sample
    return x_mean
from scipy import integrate

## The error tolerance for the black-box ODE solver
error_tolerance = 1e-5  #@param {'type':'number'}
def ode_sampler(score_model,
                marginal_prob_std,
                diffusion_coeff,
                batch_size=64,
                atol=error_tolerance,
                rtol=error_tolerance,
                device='cuda:0',
                z=None,
                eps=1e-3):
    
    # Step 1: set the initial time t=1 and the initial value x
    t = torch.ones(batch_size, device=device)
    if z is None:
        init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
            * marginal_prob_std(t)[:, None, None, None]
    else:
        init_x = z
        
    shape = init_x.shape

    # Step 2: define the score evaluation wrapper and the ODE function
    def score_eval_wrapper(sample, time_steps):
        """A wrapper of the score-based model for use by the ODE solver."""
        sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
        time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0],))
        with torch.no_grad():
            score = score_model(sample, time_steps)
        return score.cpu().numpy().reshape((-1,)).astype(np.float64)
    
    def ode_func(t, x):
        """The ODE function for use by the ODE solver."""
        time_steps = np.ones((shape[0],)) * t
        g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
        return -0.5 * (g**2) * score_eval_wrapper(x, time_steps)
    
    # Step 3: call the ODE solver to integrate down to t=eps; the result is the generated sample
    res = integrate.solve_ivp(ode_func, (1., eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45')
    print(f"Number of function evaluations: {res.nfev}")
    
    x = torch.tensor(res.y[:, -1], device=device).reshape(shape)
    
    return x

Load the trained MNIST model and compare the sampling algorithms

from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import time
%matplotlib inline

## Load the pre-trained checkpoint from disk.
device = 'cuda:0'  # cuda
ckpt = torch.load('ckpt.pth', map_location=device)
score_model.load_state_dict(ckpt)

sample_batch_size = 64
for sampler in [euler_sampler, pc_sampler, ode_sampler]:
#sampler = pc_sampler  # ['euler_sampler', 'pc_sampler', 'ode_sampler']

    t1 = time.time()
    ## Generate samples using the specified sampler.
    samples = sampler(score_model,
                      marginal_prob_std_fn,
                      diffusion_coeff_fn,
                      sample_batch_size,
                      device=device)
    t2 = time.time()
    print(f"{str(sampler)}采样耗时{t2-t1}s")

    ## Sample visualization.
    samples = samples.clamp(0.0, 1.0)
    sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))

    plt.figure(figsize=(6,6))
    plt.axis('off')
    plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
    plt.show()
def prior_likelihood(z, sigma):
    """Log-likelihood of the Gaussian prior."""
    shape = z.shape
    N = np.prod(shape[1:])
    return -N / 2. * torch.log(2 * np.pi * sigma**2) - torch.sum(z**2, dim=(1, 2, 3)) / (2 * sigma**2)

def ode_likelihood(x,
                   score_model,
                   marginal_prob_std,
                   diffusion_coeff,
                   batch_size=64,
                   device='cuda:0',
                   eps=1e-5):
    """Compute the likelihood with probability flow ODE.
    
    Args:
        x: Input data.
        score_model: A PyTorch model representing the score-based model.
        marginal_prob_std: A function that gives the standard deviation of the
            perturbation kernel.
        diffusion_coeff: A function that gives the diffusion coefficient of the
            forward SDE.
        batch_size: The batch size. Equals the leading dimension of `x`.
        device: 'cuda' for evaluation on GPUs, and 'cpu' for evaluation on CPUs.
        eps: A `float` number. The smallest time step for numerical stability.
        
    Returns:
        z: The latent code for `x`.
        bpd: The log-likelihoods in bits/dim.
    """
    # Draw the random Gaussian sample for Skilling-Hutchinson's estimator.
    epsilon = torch.randn_like(x)

    def divergence_eval(sample, time_steps, epsilon):
        """Compute the divergence of the score-based model with Skilling-Hutchinson."""
        with torch.enable_grad():
            sample.requires_grad_(True)
            score_e = torch.sum(score_model(sample, time_steps) * epsilon)
            grad_score_e = torch.autograd.grad(score_e, sample)[0]
        return torch.sum(grad_score_e * epsilon, dim=(1, 2, 3))

    shape = x.shape
    
    def score_eval_wrapper(sample, time_steps):
        """A wrapper for evaluating the score-based model for the black-box ODE solver."""
        sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
        time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
        with torch.no_grad():
            score = score_model(sample, time_steps)
        return score.cpu().numpy().reshape((-1,)).astype(np.float64)

    def divergence_eval_wrapper(sample, time_steps):
        """A wrapper for evaluating the divergence of score for the black-box opE solver."""
        with torch.no_grad():
            # Obtain x(t) by solving the probability flow ODE.
            sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
            time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
            # Compute likelihood.
            div = divergence_eval(sample, time_steps, epsilon)
            return div.cpu().numpy().reshape((-1,)).astype(np.float64)
    
    def ode_func(t, x):
        """The ODE function for the black-box solver."""
        time_steps = np.ones((shape[0],)) * t
        sample = x[:-shape[0]]
        logp = x[-shape[0]:]
        g = diffusion_coeff(torch.tensor(t)).cpu().numpy()
        sample_grad = -0.5 * g**2 * score_eval_wrapper(sample, time_steps)
        logp_grad = -0.5 * g**2 * divergence_eval_wrapper(sample, time_steps)
        return np.concatenate([sample_grad, logp_grad], axis=0)
    
    init = np.concatenate([x.cpu().numpy().reshape((-1,)), np.zeros((shape[0],))], axis=0)
    # Black-box ODE solve
    res = integrate.solve_ivp(ode_func, (eps, 1.), init, rtol=1e-5, atol=1e-5, method='RK45')
    zp = torch.tensor(res.y[:, -1], device=device)
    z = zp[:-shape[0]].reshape(shape)
    delta_logp = zp[-shape[0]:].reshape(shape[0])
    sigma_max = marginal_prob_std(1.)
    prior_logp = prior_likelihood(z, sigma_max)
    bpd = -(prior_logp + delta_logp) / np.log(2)
    N = np.prod(shape[1:])
    bpd = bpd / N + 8.
    return z, bpd
batch_size = 32 #@param {'type':'integer'}
dataset = MNIST('.', train=False, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

ckpt = torch.load('ckpt.pth', map_location=device)
score_model.load_state_dict(ckpt)

all_bpds = 0.
all_items = 0
try:
    tqdm_data = tqdm.tqdm(data_loader)
    for x, _ in tqdm_data:
        x = x.to(device)
        # uniform dequantization
        x = (x * 255. + torch.rand_like(x))/256.
        _, bpd = ode_likelihood(x, score_model, marginal_prob_std_fn,
                                diffusion_coeff_fn,
                                x.shape[0], device=device, eps=1e-5)
        all_bpds += bpd.sum()
        all_items += bpd.shape[0]
        tqdm_data.set_description("Averaye bits/dim: {:5f}".format(all_bpds / all_items))
        
except KeyboardInterrupt:
    # Remove the error message when interrupted by keyboard or GUI.
    pass