DDPM中框架代码

这里是DDPM框架的代码,等我成功把他用在超分辨率中后我会将项目代码和视频发出来。想起来SRGANext的视频还一直没做完。。。

原始论文中直接使用β作为方差,我这里使用了真实方差

 

import torch
import torch.nn.functional as F

class DenoiseDiffusion:
    def __init__(self,eps_module,n_steps,device):
        self.eps_module = eps_module
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0) # 维度与alpha相同
        self.n_steps = n_steps

    def q_xt_x0(self,x0,t):

        '''
        均值:根号下alpha_bar * x0
        方差:1-alpha_bar
        得到扩散第t步的均值和方差
        '''
        mean = (self.alpha_bar[t-1]**0.5)*x0
        var = 1-self.alpha_bar[t-1]
        return mean,var

    def q_sample(self,x0,t,eps=None):
        '''
        根据参数重整化返回采样
        '''
        if not eps:
            eps = torch.randn_like(x0)
            mean,var = self.q_xt_x0(x0,t)
            return mean+(var**0.5)*eps

    def p_sample(self,xt,t):
        '''
        xt-1的分布
        '''
        eps_theta = self.eps_module(xt,t) # 得到预测的分布
        alpha_bar = self.alpha_bar[t-1] # alpha_bar[t]
        alpha_bar_pre = self.alpha_bar[t-2] # alpha_bar[t-1]
        alpha = self.alpha[t-1]
        beta = 1-alpha
        eps_coef = beta/(1-alpha_bar)**0.5 # 预测的分布的系数
        mean = 1/(alpha**0.5)*(xt-eps_coef*eps_theta)
        var = beta*(1.0-alpha_bar_pre)/(1-alpha_bar)
        eps = torch.randn(xt.shape, device=xt.device)
        return mean + (var ** .5) * eps

    def loss(self,x0,noise):
        batch_size = x0.shape[0]
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        if not noise:
            noise = torch.randn_like(x0)
        xt = self.q_sample(x0, t, eps=noise)
        eps_theta = self.eps_module(xt,t)
        return F.mse_loss(noise,eps_theta)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值