在IDDPM中,主要改进了三个点:
·使用余弦方案生成β
·模型预测β与真实方差之间的线性权重
·混合损失:不仅包括DDPM中预测噪音与真实噪音的均方误差,还包括预测的x_{t-1}分布和真实x_{t-1}分布之间的KL散度。这里的真实x_{t-1}分布并不是通过前向过程用x_0和t-1算出来的,而是通过后验分布公式 q(x_{t-1} | x_t, x_0) 计算出来的,推导较为复杂。
这三点的详细讲解已经写在我的CSDN博客《Improved Denoising Diffusion Probabilistic Models》(管不住心的大杜)中,所以这里直接放代码。
import torch
import torch.nn.functional as F
import math
from torch.distributions import Categorical
import torch.distributions as dist
import numpy
class DenoiseDiffusion:
    """Improved DDPM (IDDPM) utilities.

    Implements the three IDDPM changes over DDPM:
      * cosine beta schedule,
      * model-predicted interpolation weight between beta_t and the true
        posterior variance,
      * hybrid loss: noise MSE plus a small KL (L_vlb) term.
    """

    def __init__(self, eps_module, n_steps, device):
        """
        eps_module: model called as eps_module(xt, t) -> (eps_theta, var_weight).
        n_steps:    total number of diffusion steps T (t runs 1..T), NOT a
                    per-call time step.
        device:     torch device the schedule tensors are moved to.
        """
        self.eps_module = eps_module
        self.beta = self.betas(n_steps).to(device)  # cosine noise schedule
        # Not the true posterior variance: IDDPM predicts a weight that
        # interpolates between beta_t and the true variance (see p_xt_t).
        self.sigma2 = self.beta
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)  # same shape as alpha
        self.n_steps = n_steps

    def betas(self, n_steps, max_beta=0.999):
        """Cosine beta schedule of Nichol & Dhariwal (2021), clipped at max_beta."""
        alpha_bar = lambda u: math.cos((u + 0.008) / 1.008 * math.pi / 2) ** 2
        return torch.Tensor([
            min(1 - alpha_bar((i + 1) / n_steps) / alpha_bar(i / n_steps), max_beta)
            for i in range(n_steps)
        ])

    def _alpha_bar_prev(self, t):
        """Return alpha_bar_{t-1} for 1-based steps t; defined as 1.0 at t == 1.

        Bug fix: the original indexed self.alpha_bar[t - 2], which for t == 1
        silently wraps to the LAST schedule entry via negative indexing.
        """
        prev = self.alpha_bar[(t - 2).clamp(min=0)]
        return torch.where(t == 1, torch.ones_like(prev), prev)

    def q_xt_x0(self, x0, t):
        """Forward marginal q(x_t | x_0): mean sqrt(ab_t) * x0, variance 1 - ab_t."""
        alpha_bar = self.alpha_bar[t - 1]  # \bar{alpha}_t, shape [B]
        # reshape to [B,1,1,1] so it broadcasts against [B,C,H,W]
        mean = (alpha_bar[:, None, None, None] ** 0.5) * x0
        var = 1 - alpha_bar
        return mean, var

    def q_sample(self, x0, t, eps=None):
        """Reparameterized sample from q(x_t | x_0)."""
        # Bug fix: `if not eps:` raises for multi-element tensors (ambiguous
        # truth value) and would discard an all-zero eps; test against None.
        if eps is None:
            eps = torch.randn_like(x0)
        mean, var = self.q_xt_x0(x0, t)
        return mean + (var ** 0.5)[:, None, None, None] * eps

    def q_posterior_xt_x0(self, x0, xt, t):
        """True posterior q(x_{t-1} | x_t, x_0) (DDPM eq. 6-7), not the model's prediction.

        mean = c0 * x0 + ct * xt with
          c0 = beta_t * sqrt(ab_{t-1}) / (1 - ab_t)
          ct = sqrt(alpha_t) * (1 - ab_{t-1}) / (1 - ab_t)
        var  = beta_t * (1 - ab_{t-1}) / (1 - ab_t)
        """
        alpha_bar = self.alpha_bar[t - 1]          # \bar{alpha}_t
        alpha_bar_prev = self._alpha_bar_prev(t)   # \bar{alpha}_{t-1}
        alpha = self.alpha[t - 1]                  # alpha_t
        beta = 1 - alpha                           # beta_t
        # Bug fix: original variance had a spurious square root on (1 - ab_{t-1}).
        posterior_var = beta * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar)
        # Bug fix: the x0 coefficient uses sqrt(ab_{t-1}), not ab_{t-1}.
        coef_x0 = beta * alpha_bar_prev ** 0.5 / (1.0 - alpha_bar)
        coef_xt = (1.0 - alpha_bar_prev) * alpha ** 0.5 / (1.0 - alpha_bar)
        # [B,1,1,1] coefficients broadcast against [B,C,H,W] images
        posterior_mean = coef_x0[:, None, None, None] * x0 + coef_xt[:, None, None, None] * xt
        return posterior_mean, posterior_var

    def q_posterior_sample(self, x0, xt, t, eps=None):
        """Reparameterized sample from the true posterior q(x_{t-1} | x_t, x_0)."""
        if eps is None:  # bug fix: was `if not eps:` (ambiguous for tensors)
            eps = torch.randn_like(x0)
        mean, var = self.q_posterior_xt_x0(x0, xt, t)
        return mean + (var ** 0.5)[:, None, None, None] * eps

    def p_xt_t(self, xt, t):
        """Model's reverse step p(x_{t-1} | x_t): returns (mean, var, eps_theta).

        The variance interpolates in log space between beta_t (upper bound)
        and the true posterior variance beta_bar_t (lower bound), weighted by
        the model's var_weight output mapped from [-1, 1] to [0, 1].
        """
        B, C = xt.shape[:2]
        assert t.shape == (B,)
        eps_theta, var_weight = self.eps_module(xt, t)
        alpha_bar = self.alpha_bar[t - 1]          # \bar{alpha}_t
        alpha_bar_prev = self._alpha_bar_prev(t)   # \bar{alpha}_{t-1}
        alpha = self.alpha[t - 1]                  # alpha_t
        beta = 1 - alpha                           # beta_t
        # true posterior variance \tilde{beta}_t
        beta_bar = beta * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar)
        frac = (var_weight + 1) / 2  # map model output [-1, 1] -> [0, 1]
        # Bug fix: the original mixed log2 with exp (base e); IDDPM uses
        # exp(frac * ln(beta) + (1 - frac) * ln(beta_bar)).
        # clamp: beta_bar == 0 at t == 1, so guard log against -inf * 0 = nan.
        log_beta_bar = torch.log(beta_bar.clamp(min=1e-20))
        model_log_variance = frac * torch.log(beta) + (1 - frac) * log_beta_bar
        var = torch.exp(model_log_variance)
        # NOTE(review): downstream code assumes var_weight (hence var) has
        # shape [B] — confirm against the eps_module implementation.
        eps_coef = beta / (1 - alpha_bar) ** 0.5
        mean = 1 / (alpha ** 0.5)[:, None, None, None] * (xt - eps_coef[:, None, None, None] * eps_theta)
        return mean, var, eps_theta

    def p_sample(self, xt, t, eps=None):
        """Sample x_{t-1} from the model's reverse distribution; also returns eps_theta."""
        if eps is None:  # bug fix: was `if not eps:` (ambiguous for tensors)
            eps = torch.randn_like(xt)
        mean, var, eps_theta = self.p_xt_t(xt, t)
        return mean + (var ** 0.5)[:, None, None, None] * eps, eps_theta

    def loss(self, x0, xt, t, noise=None):
        """Hybrid loss: MSE(noise, eps_theta) + 1e-3 * KL term.

        NOTE(review): following the original code, the KL is computed by
        treating per-pixel SAMPLES of x_{t-1} as Categorical logits over the
        last image axis — this is not the Gaussian KL of the IDDPM paper and
        is kept as the author wrote it. Unlike the reference implementation,
        no negative log-likelihood is used for t == 1 (L_0).
        """
        true_sample = self.q_posterior_sample(x0, xt, t)  # sample of x_{t-1} under q
        pred_sample, eps_theta = self.p_sample(xt, t)     # sample under p + predicted noise
        # kl_divergence needs distribution objects, not raw tensors
        true_dist = Categorical(logits=true_sample)
        pred_dist = Categorical(logits=pred_sample)
        kl_loss = dist.kl_divergence(true_dist, pred_dist)
        kl_loss = torch.mean(kl_loss) / torch.log(torch.tensor(2.0))  # nats -> bits
        # Bug fix: the original re-sampled `t` here, AFTER it had already been
        # consumed above — dead code, removed.
        if noise is None:  # bug fix: was `if not noise:` (ambiguous for tensors)
            noise = torch.randn_like(x0)
        # NOTE(review): a default `noise` is fresh randomness unrelated to the
        # noise that produced xt; pass the forward-process noise explicitly
        # for a meaningful MSE term.
        # kl_loss is L_vlb; it trains both the model and the variance weight.
        return F.mse_loss(noise, eps_theta) + 1e-3 * kl_loss

    def super_resolution(self, img, steps):
        """Iteratively subtract the model's predicted noise for `steps` steps.

        NOTE(review): `img -= eps` is not ancestral DDPM sampling (no mean
        rescaling, no added noise); kept as the author wrote it.
        """
        B = img.shape[0]
        with torch.no_grad():
            for i in range(steps, 0, -1):
                if not i % 50:  # progress report every 50 steps
                    print('当前是第', i, '步')
                t = torch.full((B,), i, device=img.device)  # time-step vector [i, i, ..., i]
                eps, _ = self.eps_module(img, t)
                img -= eps  # subtract the predicted noise
            # torch.cuda.memory_allocated() can be printed here to watch GPU memory
        return img