import math

import numpy as np
import torch
import torch.nn.functional as F

from .utils import *  # provides extract, log_1_min_a, log_add_exp, ...

eps = 1e-8
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """Return a named beta schedule as a float64 numpy array of length
    num_diffusion_timesteps."""
    if schedule_name == "linear":
        # Linear schedule from Ho et al. (2020), rescaled so that it works
        # for any number of diffusion steps, not only the original 1000.
        scale = 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif schedule_name == "cosine":
        # Cosine schedule from Nichol & Dhariwal (2021), defined through the
        # cumulative product alpha_bar(t).
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """Construct betas so that the cumulative product of (1 - beta_t) follows
    the given alpha_bar(t); each beta is clipped to max_beta to avoid
    degenerate values at the end of the schedule."""
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_t = 1 - alpha_bar(t+1) / alpha_bar(t)
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
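
# A minimal sanity check of the two schedules (an illustrative sketch, not
# part of the original module): every beta should lie in (0, max_beta].
#
#     for name in ("linear", "cosine"):
#         betas = get_named_beta_schedule(name, 100)
#         assert betas.shape == (100,)
#         assert (betas > 0).all() and (betas <= 0.999).all()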
class GaussianMultinomialDiffusion(torch.nn.Module):
    def __init__(
        self,
        num_classes: np.array,
        num_numerical_features: int,
        denoise_fn,
        num_timesteps=1000,
        gaussian_loss_type='mse',
        gaussian_parametrization='eps',
        multinomial_loss_type='vb_stochastic',
        parametrization='x0',
        scheduler='cosine',
        device=torch.device('cpu')
    ):
        super(GaussianMultinomialDiffusion, self).__init__()
        assert multinomial_loss_type in ('vb_stochastic', 'vb_all')
        assert parametrization in ('x0', 'direct')

        if multinomial_loss_type == 'vb_all':
            print('Computing the loss using the bound on _all_ timesteps.'
                  ' This is expensive both in terms of memory and computation.')
        self.num_numerical_features = num_numerical_features
        self.num_classes = num_classes  # category counts per feature, as a vector [K1, K2, ..., Km]
        # Each K_i repeated K_i times, so it broadcasts against the
        # one-hot-encoded categorical block.
        self.num_classes_expanded = torch.from_numpy(
            np.concatenate([num_classes[i].repeat(num_classes[i]) for i in range(len(num_classes))])
        ).to(device)

        # Column index ranges of each categorical feature inside the one-hot block.
        self.slices_for_classes = [np.arange(self.num_classes[0])]
        offsets = np.cumsum(self.num_classes)
        for i in range(1, len(offsets)):
            self.slices_for_classes.append(np.arange(offsets[i - 1], offsets[i]))
        self.offsets = torch.from_numpy(np.append([0], offsets)).to(device)
        self._denoise_fn = denoise_fn
        self.gaussian_loss_type = gaussian_loss_type
        self.gaussian_parametrization = gaussian_parametrization
        self.multinomial_loss_type = multinomial_loss_type
        self.num_timesteps = num_timesteps
        self.parametrization = parametrization
        self.scheduler = scheduler

        alphas = 1. - get_named_beta_schedule(scheduler, num_timesteps)
        alphas = torch.tensor(alphas.astype('float64'))
        betas = 1. - alphas

        # Log-space schedule quantities used by the multinomial (categorical) part.
        log_alpha = torch.log(alphas)
        log_cumprod_alpha = torch.cumsum(log_alpha, dim=0)
        log_1_min_alpha = log_1_min_a(log_alpha)                  # log(1 - alpha_t)
        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)  # log(1 - alpha_bar_t)

        # Cumulative-product quantities used by the Gaussian part.
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = torch.tensor(np.append(1.0, alphas_cumprod[:-1]))
        alphas_cumprod_next = torch.tensor(np.append(alphas_cumprod[1:], 0.0))
        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
        sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
        # Gaussian diffusion: parameters of the closed-form posterior
        # q(x_{t-1} | x_t, x_0).
        self.posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # The variance is 0 at t = 0, so the log is "clipped" by reusing
        # the value at t = 1 for the first entry.
        self.posterior_log_variance_clipped = torch.log(
            torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]])
        ).float().to(device)
        self.posterior_mean_coef1 = (
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        ).float().to(device)
        self.posterior_mean_coef2 = (
            (1.0 - alphas_cumprod_prev)
            * torch.sqrt(alphas)
            / (1.0 - alphas_cumprod)
        ).float().to(device)
        # Sanity checks: exp(log_alpha) + exp(log_1_min_alpha) must equal 1
        # (so each log-sum-exp is 0), and the cumulative sums must agree.
        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
        assert (torch.cumsum(log_alpha, dim=0) - log_cumprod_alpha).abs().sum().item() < 1e-5
        # Convert to float32 and register as buffers so they move with the
        # module on .to(device) and are included in state_dict.
        self.register_buffer('alphas', alphas.float().to(device))
        self.register_buffer('log_alpha', log_alpha.float().to(device))
        self.register_buffer('log_1_min_alpha', log_1_min_alpha.float().to(device))
        self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float().to(device))
        self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float().to(device))
        self.register_buffer('alphas_cumprod', alphas_cumprod.float().to(device))
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev.float().to(device))
        self.register_buffer('alphas_cumprod_next', alphas_cumprod_next.float().to(device))
        self.register_buffer('sqrt_alphas_cumprod', sqrt_alphas_cumprod.float().to(device))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', sqrt_one_minus_alphas_cumprod.float().to(device))
        self.register_buffer('sqrt_recip_alphas_cumprod', sqrt_recip_alphas_cumprod.float().to(device))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', sqrt_recipm1_alphas_cumprod.float().to(device))

        # Running per-timestep loss statistics, used to importance-sample t.
        self.register_buffer('Lt_history', torch.zeros(num_timesteps))
        self.register_buffer('Lt_count', torch.zeros(num_timesteps))
    # Gaussian part

    def gaussian_q_mean_variance(self, x_start, t):
        """Mean and (log-)variance of q(x_t | x_0) =
        N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)."""
        mean = (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        )
        variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(
            self.log_1_min_cumprod_alpha, t, x_start.shape
        )
        return mean, variance, log_variance
    def gaussian_q_sample(self, x_start, t, noise=None):
        """Draw x_t ~ q(x_t | x_0) via the reparametrization
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps."""
        if noise is None:
            noise = torch.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )
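
    # A quick illustration of the forward process (a hypothetical sketch, not
    # part of the original class; `diffusion` is an instantiated
    # GaussianMultinomialDiffusion):
    #
    #     x0 = torch.randn(16, diffusion.num_numerical_features)
    #     t = torch.full((16,), diffusion.num_timesteps - 1, dtype=torch.long)
    #     xt = diffusion.gaussian_q_sample(x0, t)
    #     # alpha_bar_t is near 0 at the last step, so xt is almost pure noise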
    def gaussian_q_posterior_mean_variance(self, x_start, x_t, t):
        """Mean and variance of the Gaussian posterior q(x_{t-1} | x_t, x_0)."""
        assert x_start.shape == x_t.shape
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
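
# A minimal usage sketch (illustrative only, with hypothetical names:
# `MyDenoiser` stands for any eps- or x0-predicting network, and the full
# class defines more methods than the excerpt above shows):
#
#     num_classes = np.array([3, 5])          # two categorical features
#     diffusion = GaussianMultinomialDiffusion(
#         num_classes=num_classes,
#         num_numerical_features=4,
#         denoise_fn=MyDenoiser(),
#         num_timesteps=1000,
#         scheduler='cosine',
#     )
#     t = torch.randint(0, diffusion.num_timesteps, (batch,))
#     x_num_t = diffusion.gaussian_q_sample(x_num_0, t)  # noised numerical block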