扩散模型-Diffusion Model-DDPM、DDIM

本文详细阐述了扩散模型中的前向过程,通过添加噪声并记录过程,以及反向过程中的噪声预测,使用U-net模型进行噪声预测和逆采样。介绍了关键的贝叶斯公式和损失函数计算方法。
摘要由CSDN通过智能技术生成


扩散模型主要分为前向过程与后向过程,其中前向过程主要是通过不断给数据集中的图像加噪声,并记录下整个过程,后向过程是一个逐渐去噪的过程,这里的关键就是预测一下每一步减去的噪声,从而获取一步去噪后的分布情况。
在这里插入图片描述

DDPM

一、前向过程

  • 首先令: α t = 1 − β t \alpha_t=1-\beta_t αt=1βt
    其中 β t \beta_t βt随着t的增加越来越大,论文中是从0.0001到0.002。
  • x t = a t x t − 1 + 1 − α t z t ( 1 ) x_t=\sqrt{a_t}x_{t-1}+\sqrt{1-\alpha_t}z_t (1) xt=at xt1+1αt zt1 此公式的意思是,t时刻的图像是有前一时刻图像加上一个从正态分布中采样的噪声加权相加的过程。其中 z ∼ N ( 0 , 1 ) z\sim\mathcal{N}(0,\mathbf{1}) zN(0,1)
  • 由式(1)递推,可得到: x t = a t ( a t − 1 x t − 2 + 1 − α t − 1 z t − 1 ) + 1 − α t z t x_t=\sqrt{a_t}\big(\sqrt{a_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}z_{t-1}\big)+\sqrt{1-\alpha_t}z_t xt=at (at1 xt2+1αt1 zt1)+1αt zt = a t a t − 1 x t − 2 + ( a t ( 1 − α t − 1 ) z t − 1 + 1 − α t z t ) =\sqrt{a_ta_{t-1}}x_{t-2}+(\sqrt{a_t(1-\alpha_{t-1})}z_{t-1}+\sqrt{1-\alpha_t}z_t) =atat1 xt2+(at(1αt1) zt1+1αt zt) 此时根据高斯分布的性质,可得: = a t a t − 1 x t − 2 + 1 − α t α t − 1 z ‾ t =\sqrt{a_ta_{t-1}}x_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\overline{z}_{t} =atat1 xt2+1αtαt1 zt 其中 z ‾ t ∼ N ( 0 , 1 ) \overline{z}_{t}\sim\mathcal{N}(0,\mathbf{1}) ztN(0,1)
    继续迭代,可得最终前向公式: x t = α ‾ t x 0 + 1 − α ‾ t z ‾ t x_t=\sqrt{\overline{\alpha}_t}x_0+\sqrt{1-\overline{\alpha}_t}\overline{z}_t xt=αt x0+1αt zt 其中, α ‾ t \overline{\alpha}_t αt是表示累乘。

二、反向过程

原理主要是通过后一张的分布预测处前一张图像的分布情况。

  • 首先根据贝叶斯公式,可得: q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)=q(\mathbf{x}_t|\mathbf{x}_{t-1},\mathbf{x}_0)\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_0)}{q(\mathbf{x}_t|\mathbf{x}_0)} q(xt1xt,x0)=q(xtxt1,x0)q(xtx0)q(xt1x0) 其中q代表的意思是该图像的分布情况。
  • 由前向过程我们可得:在这里插入图片描述
    上面三个式子对应的概率密度函数,分别为:
    在这里插入图片描述

再带入到贝斯公式中:
在这里插入图片描述
整理成正态分布的概率密度形式:
1 2 π ( 1 − a t 1 − a ˉ t − 1 1 − a ˉ t ) e [ − ( x t − 1 − ( a t ( 1 − a ˉ t − 1 ) 1 − a ˉ t x t + a t − 1 ( 1 − a t ) 1 − a ˉ t x 0 ) ) 2 2 ( 1 − a t 1 − a ˉ t − 1 1 − a ˉ t ) 2 ] \frac1{\sqrt{2\pi}\left(\color{red}{\frac{\sqrt{1-a_t}\sqrt{1-\bar{a}_{t-1}}}{\sqrt{1-\bar{a}_t}}}\right)}e^{\left[-\frac{\left(x_{t-1}-\left(\color{red}\frac{\sqrt{a_t}(1-\bar{a}_{t-1})}{1-\bar{a}_t}x_t+\frac{\sqrt{a_{t-1}}(1-a_t)}{1-\bar{a}_t}x_0\right)\right)^2}{2\left(\color{red}\frac{\sqrt{1-a_t}\sqrt{1-\bar{a}_{t-1}}}{\sqrt{1-\bar{a}_t}}\right)^2}\right]} 2π (1aˉt 1at 1aˉt1 )1e 2(1aˉt 1at 1aˉt1 )2(xt1(1aˉtat (1aˉt1)xt+1aˉtat1 (1at)x0))2

其中,x0我们可以通过前向过程公式得到: x 0 = 1 α ˉ t ( x t − 1 − α ˉ t z t ) \mathbf{x}_0=\frac1{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t}\mathbf{z}_t) x0=αˉt 1(xt1αˉt zt) 现在获取Zt是获得前一张图像分布的关键。

最终分布:
在这里插入图片描述
其中 ϵ \epsilon ϵ就是 z t z_t zt一个逆向噪声。
上式还可以化简,最后可得分布的均值,方差为:
在这里插入图片描述
在这里插入图片描述

三、重采样

首先任何正态分布都可以从标准正态分布变换得到,正态分布转标准正态分布式减均值除标准差,因此标准正态分布转其他分部只需要乘标准差加均值即可,因此最终采样公式如下:
在这里插入图片描述

四、预测噪声

现在后向过程的关键就是获取一次抽样的噪声了,只能通过模型来进行预测。
通常采用U-net模型来进行,其中模型的输入参数有两个,分别是当前时刻的分布以及当前时刻t。

训练以及采样流程:
在这里插入图片描述
预测噪声的部分代码:

噪声模型:

import torch
import torch.nn as nn

class MLPDiffusion(nn.Module)  :
    def __init__(self,n_steps,num_units=128):
        super(MLPDiffusion,self).__init__()
        
        self.linears = nn.ModuleList(
            [
                nn.Linear(2,num_units),
                nn.ReLU(),
                nn.Linear(num_units,num_units),
                nn.ReLU(),
                nn.Linear(num_units,num_units),
                nn.ReLU(),
                nn.Linear(num_units,2),
            ]
        )
        self.step_embeddings = nn.ModuleList(
            [
                nn.Embedding(n_steps,num_units),
                nn.Embedding(n_steps,num_units),
                nn.Embedding(n_steps,num_units),
            ]
        )
    def forward(self,x,t):
#         x = x_0
        for idx,embedding_layer in enumerate(self.step_embeddings):
            t_embedding = embedding_layer(t)
            x = self.linears[2*idx](x)
            x += t_embedding
            x = self.linears[2*idx+1](x)
            
        x = self.linears[-1](x)  #输出噪声 shape:128*2
        
        return x

损失函数:

def diffusion_loss_fn(model,x_0,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,n_steps):
    """对任意时刻t进行采样计算loss"""
    batch_size = x_0.shape[0]
    
    #对一个batchsize样本生成随机的时刻t
    t = torch.randint(0,n_steps,size=(batch_size//2,))  #0-100中间随机生成时刻
    t = torch.cat([t,n_steps-1-t],dim=0)
    t = t.unsqueeze(-1)  #为了让获取的系数与x_0形状类似以至于可以广播相乘
    
    #x0的系数
    a = alphas_bar_sqrt[t]
    
    #eps的系数
    aml = one_minus_alphas_bar_sqrt[t]
    
    #生成随机噪音eps
    e = torch.randn_like(x_0)
    
    #构造模型的输入
    x = x_0*a+e*aml
    
    #送入模型,得到t时刻的随机噪声预测值
    output = model(x,t.squeeze(-1))
    
    #与真实噪声一起计算误差,求平均值
    return (e - output).square().mean()

注意:这里的e就是正向过程中加进去的噪声,只不过这里没有单独走一遍前向过程,而是在反向过程的途中再算正向过程的那个噪声,这样就不用存每次加的噪声了。
逆采样过程:

def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):
    """从x[T]恢复x[T-1]、x[T-2]|...x[0]"""
    cur_x = torch.randn(shape)
    x_seq = [cur_x]
    for i in reversed(range(n_steps)):
        cur_x = p_sample(model,cur_x,i,betas,one_minus_alphas_bar_sqrt)
        x_seq.append(cur_x)
    return x_seq

def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):
    """从x[T]采样t时刻的重构值"""
    t = torch.tensor([t])
    
    coeff = betas[t] / one_minus_alphas_bar_sqrt[t]
    
    eps_theta = model(x,t)
    
    mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))
    
    z = torch.randn_like(x)
    sigma_t = betas[t].sqrt()
    
    sample = mean + sigma_t * z
    
    return (sample)

DDIM

用来加速DDPM的反向过程,使其反向过程不遵循马尔可夫链。

  1. 假设 P ( x t − 1 ∣ x t , x 0 ) P(x_{t-1}|x_t, x_0) P(xt1xt,x0)满足以下正态分布:
    在这里插入图片描述
  2. 因为加载过程满足公式:
    在这里插入图片描述
  3. 代入得:
    在这里插入图片描述
  4. 同样,通过正向过程,也可得出 x t − 1 x_{t-1} xt1的图像:
    在这里插入图片描述
  5. 根据系数相同,可得:
    在这里插入图片描述
  6. 可得, P ( x t − 1 ∣ x t , x 0 ) P(x_{t-1}|x_t, x_0) P(xt1xt,x0)分布为:
    在这里插入图片描述
  7. 同样采用如下替换:
    在这里插入图片描述
  8. 可得 x t − 1 x_{t-1} xt1重采样为:
    在这里插入图片描述
    9.在这里插入图片描述
    关于DDIM也可参考这篇博文https://kexue.fm/archives/9181

总结

代码在资源中。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值