IDDPM框架代码

在IDDPM中,主要改进了三个点:

        ·使用余弦方案生成β

        ·模型预测β与真实方差之间的线性权重

        ·混合损失,不但包括DDPM中预测噪音与真实噪音的均方误差,还有预测的xt-1分布和真实xt-1分布的KL散度。但是这里真实xt-1分布并不是通过前向过程使用x0和t-1算出来的,是通过一个复杂公式计算出来,并不是很理解。

这三点我的讲解已经在我的Improved Denoising Diffusion Probabilistic Models_管不住心的大杜的博客-CSDN博客

写了,所以直接放代码。 

import torch
import torch.nn.functional as F
import math
from torch.distributions import Categorical
import torch.distributions as dist
import numpy



class DenoiseDiffusion:
    def __init__(self,eps_module,n_steps,device):
        # n_steps是时间步最大值,不是每一次的时间步
        self.eps_module = eps_module
        self.beta = self.betas(n_steps).to(device) # 余弦加噪方案
        self.sigma2 = self.beta # 非真实方差,IDDPM会预测β和真实方差之间的线性加权的权重
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0) # 维度与alpha相同
        self.n_steps = n_steps

    def betas(self,n_steps,max_beta=0.999):
        # 余弦加噪方案生成Beta
        betas = []
        alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
        for i in range(n_steps):
            t1 = i / n_steps
            t2 = (i + 1) / n_steps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        return torch.Tensor(betas)

    def q_xt_x0(self,x0,t):

        '''
        均值:根号下alpha_bar * x0
        方差:1-alpha_bar
        得到第t时刻分布的均值和方差
        '''
        mean = (self.alpha_bar[t-1][:,None,None,None] ** 0.5) * x0 # 必须是[B,1,1,1]和[B,C,H,W]才能相乘
        var = 1-self.alpha_bar[t-1]
        return mean,var

    def q_sample(self,x0,t,eps=None):
        '''
        根据参数重整化返回采样
        '''
        if not eps:
            eps = torch.randn_like(x0)
        mean,var = self.q_xt_x0(x0,t)
        return mean+ (var**0.5)[:,None,None,None] * eps # 必须是[B,1,1,1]和[B,C,H,W]才能相乘

    def q_posterior_xt_x0(self,x0,xt,t):
        '''
        根据x0和Xt得到xt-1的真实分布,而非后向过程中的预测的分布
        '''
        alpha_bar = self.alpha_bar[t - 1]  # alpha_bar[t]
        alpha_bar_pre = self.alpha_bar[t - 2]  # alpha_bar[t-1]
        alpha = self.alpha[t - 1] # alpha[t]
        beta = 1 - alpha # beta[t]
        posterior_var = (
                beta * (1.0 - alpha_bar_pre)**0.5 / (1.0 - alpha_bar)
        )
        posterior_mean_coef1 = (
                beta * (alpha_bar_pre) / (1.0 -alpha_bar)
        )
        posterior_mean_coef2 = (
                (1.0 - alpha_bar_pre)
                * alpha**0.5
                / (1.0 - alpha_bar)
        )
        posterior_mean = x0*posterior_mean_coef1[:,None,None,None] + xt*posterior_mean_coef2[:,None,None,None] # 必须是[B,1,1,1]和[B,C,H,W]才能相乘
        return posterior_mean,posterior_var

    def q_posterior_sample(self,x0,xt,t,eps=None):
        '''
                根据参数重整化返回采样
        '''
        if not eps:
            eps = torch.randn_like(x0)
        mean, var = self.q_posterior_xt_x0(x0,xt,t)
        return mean + (var ** 0.5)[:,None,None,None] * eps

    def p_xt_t(self,xt,t):
        '''
        返回根据模型预测的分布预测xt-1的均值,方差,以及模型预测的噪音分布
        '''
        B, C = xt.shape[:2]
        assert t.shape == (B,)
        eps_theta, var_weight = self.eps_module(xt, t)
        alpha_bar = self.alpha_bar[t-1] # alpha_bar[t]
        alpha_bar_pre = self.alpha_bar[t-2] # alpha_bar[t-1]
        alpha = self.alpha[t-1] # alpha[t]
        beta = 1-alpha # beta[t]
        beta_bar = beta*(1.0-alpha_bar_pre)/(1.0-alpha_bar) # 真实方差
        frac = ((var_weight + 1) / 2)
        model_log_variance = frac * torch.log2(beta) + (1 - frac) * torch.log2(beta_bar) # 线性相加
        var = torch.exp(model_log_variance)
        eps_coef = beta/(1-alpha_bar)**0.5 # 预测的分布的系数
        mean = 1/(alpha**0.5)[:,None,None,None]*(xt-eps_coef[:,None,None,None]*eps_theta) # 必须是[B,1,1,1]和[B,C,H,W]才能相乘
        # 将模型预测的噪音也返回
        return mean,var,eps_theta

    def p_sample(self,xt,t,eps=None):
        '''返回Xt-1的分布和模型预测的噪音分布'''
        if not eps:
            eps = torch.randn_like(xt)
        mean, var,eps_theta = self.p_xt_t(xt, t)
        return mean + (var ** 0.5)[:,None,None,None] * eps,eps_theta # 必须是[B,1,1,1]和[B,C,H,W]才能相乘


    def loss(self,x0,xt,t,noise=None):
        true_dis = self.q_posterior_sample(x0,xt,t) # Xt-1的分布
        pred_dis,eps_theta = self.p_sample(xt, t) # 预测Xt-1的分布,模型预测的噪音分布
        # kl_divergence函数不能计算两个张量的KL散度,需要转换为概率分布对象
        true_dis = Categorical(logits=true_dis)
        pred_dis = Categorical(logits=pred_dis)
        # KL散度,这里没有像源代码一样使用L[0]的负对数似然
        kl_loss = dist.kl_divergence(true_dis, pred_dis)
        kl_loss = torch.mean(kl_loss) / torch.log(torch.tensor(2.0))
        batch_size = x0.shape[0]
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        if not noise:
            noise = torch.randn_like(x0)
        # kl_loss即Lvlb,它包含模型,也包含预测的线性权重
        return F.mse_loss(noise,eps_theta)+1e-3*kl_loss

    def super_resolution(self,img,steps):
        B = img.shape[0]
        with torch.no_grad():
            for i in range(steps,0,-1):
                if not i%50:
                    print('当前是第',i,'步')
                t = torch.full((B,),i,device=img.device) # 得到[i,i,i...i]的时间步骤向量
                eps,_ = self.eps_module(img,t)
                img -= eps # 减去模型预测到的噪音
                # print('Current GPU memory usage: {:.2f} GB'.format(torch.cuda.memory_allocated() / 1024 ** 3)) # 可以查看当前显存占用
        return img

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值