在Diffusion原理(一),Diffusion原理(二)中分别介绍了Diffusion(DDPM)的前向过程和逆向过程原理。这里将进一步介绍一下DDPM的代码实现.
首先我们从整体来思考一下,DDPM的代码实现,会包含哪些部分。可以利用思维导图大致梳理一下:
接下来根据思维导图来逐个模块实现。
1. 计算, 我们定义一个函数,来获取每一步的. 在DDPM中,采用的是线性生成方式。每一步的成线性增加的。输入参数timesteps是总的步数。由于默认总步数是1000步,但是为了适配不同的定义,可以自己指定总步数,因此这里就会对每一步的步长进行缩放。具体代码如下:
def linear_beta_schedule(timesteps):
totalstep = 1000
value_start = 0.0001
value_end = 0.02
scale = totalstep / timesteps
beta_start = scale * value_start
beta_end = scale * value_end
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
2. 根据beta计算其他前向过程的参数。这里我们定义一个dGaussianDiffusion的类。先定义属性,利用这些属性(存放在register_buffer中)来设置这些固定的参数,不会参与参数更新。
class GaussianDiffusion(nn.Module):
def __init__(self,
opts,
device,
network,
min_snr_loss_weight=True):
super().__init__()
self.opts = opts
self.device = device
self.network = network.to(device)
# define betas: beta1, beta2, ... beta_n
beta_schedule_fn = linear_beta_schedule
betas = beta_schedule_fn(self.opts['timesteps'])
self.num_timesteps = int(betas.shape[0])
# define alphas
# get a1, a2, ..., an
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
sqrt_recip_alphas = 1.0 / torch.sqrt(alphas)
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
register_buffer('betas', betas)
register_buffer('alphas_cumprod', alphas_cumprod)
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
register_buffer('sqrt_recip_alphas', sqrt_recip_alphas)
# calculations for diffusion q(x_t | x_{t-1}) and others
# x_t = sqrt(alphas_cumprod)* x_0 + sqrt(1 - alphas_cumprod) * noise
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))
# calculations for posterior q(x_{t - 1} | x_t, x_0)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
register_buffer('posterior_variance', posterior_variance)
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod)) # A
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1.0 / alphas_cumprod - 1.0)) # B
# mu_{t - 1} = mean_coef1 * clip(x_{0}) + mean_coef2 * x_{t}
register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) /
(1.0 - alphas_cumprod))
snr = alphas_cumprod / (1.0 - alphas_cumprod)
maybe_clipped_snr =snr.clone()
if min_snr_loss_weight:
maybe_clipped_snr.clamp_(max=self.opts['min_snr_gamma'])
register_buffer('loss_weight', maybe_clipped_snr / snr)
self.ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
3. 接下来是正向过程:这里完全按照原理部分的公式来获取对应前向过程t时刻的样本。
def q_sample(self, x_start, t, noise=None):
sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
sqrt_one_minus_alphas_cumprod_t = extract(
self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
)
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
其中extract的实现如下,这个代表根据t来抽取对应位置的值。
def extract(tensor, t, x_shape):
batch_size =t.shape[0]
out = tensor.gather(-1, t.cpu())
return out.reshape(batch_size,
*((1, ) * (len(x_shape) - 1))).to(t.device)
比如 sqrt_alphas_cumprod_t 就是 中的. sqrt_alphas_cumprod其实是所有从0-T的这些alpha在每个时刻都保存的一个tensor。
4. 训练时求loss。前面已经知道了前向过程,loss计算只拟合噪声,这个时候我们就可以得到直接计算loss了。在DDPM原理介绍中,我们最后推到出了整个loss其实可以表达为公式:
这里就会要求,求得每一步的噪声误差的总和。但是呢,为了如果每一步都去求,那么未免太耗时了。比如说总步长是1000, 那么每一个样本求1000次的loss总和未免太耗费训练时间了。这里为了简便,会对每个训练batch中的每张图,对应在[1, 1000]中随机生成一个步长timestep。比如一个batch有四个样本,每个样本的步长在[1, 1000]中随机生成,比如可以是[50, 100, 94, 786]。具体代码如下所示。
def p_losses(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_t = self.q_sample(x_start=x_start, t=t, noise=noise)
network_out = self.network(x_t, t)
target = noise
if self.opts['loss_type'] == 'huber':
loss = F.smooth_l1_loss(network_out, target, reduction='none')
elif self.opts['loss_type'] == 'l1':
loss = F.l1_loss(network_out, target, reduction='none')
elif self.opts['loss_type'] == 'l2':
loss = F.mse_loss(network_out, target, reduction='none')
else:
raise NotImplementedError()
loss = reduce(loss, 'b ... -> b (...)', 'mean')
loss = loss * extract(self.loss_weight, t, loss.shape)
return loss.mean()
def forward(self, img):
b, _, _, _ = img.shape
t = torch.randint(0, self.num_timesteps, (b,), device=self.device).long()
return self.p_losses(img, t)
5. 接下来要实现逆向生成过程:逆向生成过程也是按照前面原理部分的公式得到的,只是把数学公式用代码表达而已。这里的stable_sampling只是为了更好实现逆向生成,做了一些数学公式上的变换。
@torch.inference_mode()
def p_sample(self, x, t, t_index):
betas_t = extract(self.betas, t, x.shape)
sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)
x_mean = sqrt_recip_alphas_t * (
x - betas_t * self.network(x, t) / sqrt_one_minus_alphas_cumprod_t
)
if t_index == 0:
return x_mean
else:
posterior_variance_t = extract(self.posteroir_variance, t, x.shape)
noise = torch.randn_like(x)
return x_mean + torch.sqrt(posterior_variance_t) * noise
上面的p_sample只是一步的逆向生成,要实现从Xt到X0的生成,需要一个循环,如下所示:
@torch.inference_mode()
def p_sample_loop(self, shape, return_all_timesteps=False):
batch_size = self.opts['sample_batch_size']
image = torch.randn(shape, device=self.device)
return_images = [image.cpu().numpy()]
for i in tqdm(reversed(range(0, self.opts['timesteps'])),
desc='sampling loop time step',
total=self.opts['timesteps']):
image = self.p_sample(image, torch.full((batch_size, ), i,
device=self.device, dtype=torch.long), i)
if return_all_timesteps:
return_images.append(image.cpu().numpy())
else:
if i == 0:
return_images.append(image.cpu().numpy())
return return_images
至此,原始DDPM的代码已经实现了,还差一个生成噪声的network的定义,之后将在下一次代码详解中介绍。