# 这里是 DDPM 框架的代码。等我成功把它用在超分辨率任务中后,会将项目代码和视频发出来。想起来 SRGANext 的视频还一直没做完……
# 原始论文中反向过程直接使用 β_t 作为方差,我这里使用了真实的后验方差 \tilde{\beta}_t = \beta_t (1-\bar{\alpha}_{t-1}) / (1-\bar{\alpha}_t)。
import torch
import torch.nn.functional as F
class DenoiseDiffusion:
    """DDPM forward (noising) and reverse (denoising) processes.

    Timesteps use a 1-based convention throughout: ``t`` ranges over
    ``1..n_steps`` and ``beta[t - 1]`` is the noise variance of step ``t``.
    The reverse step uses the true posterior variance
    ``beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)`` instead of the
    ``beta_t`` used in the original paper.
    """

    def __init__(self, eps_module, n_steps, device, beta_start=0.0001, beta_end=0.02):
        """
        Args:
            eps_module: callable ``eps_module(xt, t)`` that predicts the noise
                contained in ``xt`` at timestep ``t``.
            n_steps: number of diffusion steps T.
            device: device on which the schedule tensors are stored.
            beta_start: first value of the linear beta schedule.
            beta_end: last value of the linear beta schedule.
        """
        self.eps_module = eps_module
        self.beta = torch.linspace(beta_start, beta_end, n_steps).to(device)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)  # same length as alpha
        # alpha_bar shifted right by one with alpha_bar_0 = 1, so that
        # alpha_bar_prev[t - 1] == alpha_bar_{t-1} is well defined for t = 1
        # (indexing alpha_bar[t - 2] there would wrap around to the last step).
        self.alpha_bar_prev = torch.cat(
            [self.alpha_bar.new_ones(1), self.alpha_bar[:-1]]
        )
        self.n_steps = n_steps

    @staticmethod
    def _gather(values, t, x):
        """Return ``values[t - 1]`` reshaped to broadcast against ``x``.

        ``t`` may be a Python int, a 0-d tensor, or a 1-d tensor holding one
        timestep per leading (batch) entry of ``x``; the result has shape
        ``(B, 1, ..., 1)`` (or ``(1, 1, ..., 1)`` for a scalar ``t``).
        """
        picked = values[t - 1]
        return picked.reshape(-1, *([1] * (x.dim() - 1)))

    def q_xt_x0(self, x0, t):
        """Mean and variance of the forward marginal q(x_t | x_0).

        mean = sqrt(alpha_bar_t) * x0
        var  = 1 - alpha_bar_t   (shaped to broadcast against x0)
        """
        alpha_bar = self._gather(self.alpha_bar, t, x0)
        mean = (alpha_bar ** 0.5) * x0
        var = 1 - alpha_bar
        return mean, var

    def q_sample(self, x0, t, eps=None):
        """Draw x_t ~ q(x_t | x_0) with the reparameterization trick.

        If ``eps`` is None, fresh standard-normal noise shaped like ``x0``
        is used.
        """
        # `not eps` would raise for multi-element tensors; compare to None.
        if eps is None:
            eps = torch.randn_like(x0)
        mean, var = self.q_xt_x0(x0, t)
        return mean + (var ** 0.5) * eps

    def p_sample(self, xt, t):
        """Draw x_{t-1} from the learned reverse distribution p(x_{t-1} | x_t).

        The mean is computed from the predicted noise; the variance is the true
        posterior variance, which is 0 at t = 1 so the last step is
        deterministic.
        """
        eps_theta = self.eps_module(xt, t)  # predicted noise
        alpha_bar = self._gather(self.alpha_bar, t, xt)            # alpha_bar_t
        alpha_bar_prev = self._gather(self.alpha_bar_prev, t, xt)  # alpha_bar_{t-1}
        alpha = self._gather(self.alpha, t, xt)
        beta = 1 - alpha
        eps_coef = beta / (1 - alpha_bar) ** 0.5  # coefficient of predicted noise
        mean = (xt - eps_coef * eps_theta) / alpha ** 0.5
        var = beta * (1.0 - alpha_bar_prev) / (1 - alpha_bar)
        eps = torch.randn_like(xt)
        return mean + (var ** .5) * eps

    def loss(self, x0, noise=None):
        """Simplified DDPM training loss: MSE between true and predicted noise.

        Args:
            x0: clean batch; dimension 0 is taken as the batch size.
            noise: optional noise to use for the forward sample (defaults to
                standard normal noise shaped like ``x0``).
        """
        batch_size = x0.shape[0]
        # Sample t uniformly from 1..n_steps to match the 1-based indexing used
        # by q_xt_x0 / p_sample (t = 0 would index alpha_bar[-1], the last step).
        # NOTE(review): eps_module now sees t in 1..n_steps in both training
        # and sampling; confirm its time embedding accepts t == n_steps.
        t = torch.randint(1, self.n_steps + 1, (batch_size,), device=x0.device, dtype=torch.long)
        if noise is None:
            noise = torch.randn_like(x0)
        xt = self.q_sample(x0, t, eps=noise)
        eps_theta = self.eps_module(xt, t)
        return F.mse_loss(eps_theta, noise)