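In the previous article, using huggingface as the example, I analyzed how the model adds noise during training and how the UNet learns to predict that noise. This article uses the open-source diffusers library to analyze how a diffusion model denoises at inference time.
Before reading, you should already be familiar with DDPM, LDM and Stable Diffusion; otherwise this article will be hard to follow.
Take StableDiffusionXLInpaintPipeline as the example. The denoising loop in its __call__ method, shown below, tells us two things:
- The UNet takes the latent matrix at timestep t as input and predicts the noise; the predicted noise has the same shape as the input.
- self.scheduler.step takes the latent matrix at timestep t and the predicted noise, combines them with scheduler-specific weights (in effect subtracting a weighted share of the noise), and returns the latent matrix at timestep t-1. That result becomes the input of the next loop iteration.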
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        ...
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=timestep_cond,
            cross_attention_kwargs=self.cross_attention_kwargs,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]
        ...
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
        ...
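The two observations above can be reproduced with a toy, scalar version of the loop. This is only a sketch under stated assumptions: `alpha_sigma`, `toy_unet`, `toy_scheduler_step` and the constant `X0` are hypothetical stand-ins invented for illustration, the noise schedule is an arbitrary variance-preserving one, and the scheduler step is written in the simple DDIM form rather than the scheduler the pipeline actually uses.

```python
import math

X0 = 2.0  # hypothetical "clean latent" that the toy model has memorized


def alpha_sigma(t, num_train_steps=1000):
    """Hypothetical variance-preserving schedule: alpha^2 + sigma^2 == 1."""
    theta = 0.5 * math.pi * t / num_train_steps
    return math.cos(theta), math.sin(theta)


def toy_unet(latent, t):
    """Stand-in for self.unet: returns the noise contained in `latent` at step t."""
    alpha, sigma = alpha_sigma(t)
    return (latent - alpha * X0) / sigma


def toy_scheduler_step(noise_pred, t, t_prev, latent):
    """Stand-in for self.scheduler.step, written in DDIM form."""
    alpha, sigma = alpha_sigma(t)
    alpha_prev, sigma_prev = alpha_sigma(t_prev)
    # Subtract the weighted noise to estimate the clean latent ...
    x0_pred = (latent - sigma * noise_pred) / alpha
    # ... then re-noise that estimate to the (lower) noise level of t_prev.
    return alpha_prev * x0_pred + sigma_prev * noise_pred


def denoise(initial_noise, timesteps):
    """Run the loop: predict noise, step the scheduler, feed the result back in."""
    alpha, sigma = alpha_sigma(timesteps[0])
    latent = alpha * X0 + sigma * initial_noise  # noisy starting latent
    for t, t_prev in zip(timesteps[:-1], timesteps[1:]):
        noise_pred = toy_unet(latent, t)
        latent = toy_scheduler_step(noise_pred, t, t_prev, latent)
    return latent


final_latent = denoise(initial_noise=0.7, timesteps=[900, 700, 500, 300, 100, 0])
print(final_latent)  # close to X0, because the toy noise predictor is exact
```

Because the toy predictor is an oracle, the loop recovers `X0` up to floating-point error; a real UNet only approximates the noise, which is why the choice of scheduler and number of steps matters in practice.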
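Next, follow the call into self.scheduler.step. The scheduler used here is DPMSolverMultistepScheduler. On the first step (when lower_order_nums is still 0), or whenever solver_order is 1 or lower_order_final holds, the call is dispatched to dpm_solver_first_order_update; later steps may use the second- or third-order multistep updates.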
def step(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    ...
    if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
        prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
    elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
        prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
    else:
        prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)
    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1
    ...
    return SchedulerOutput(prev_sample=prev_sample)
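To see which update the branch in step picks over successive calls, the dispatch can be isolated into a small sketch. The function names `select_update_order` and `orders_over_steps` are hypothetical, and the two flags `lower_order_final` and `lower_order_second` are computed in the elided part of step, so here they are simply taken as inputs and left at False in the simulation.

```python
def select_update_order(solver_order, lower_order_nums,
                        lower_order_final=False, lower_order_second=False):
    """Mirror the if/elif/else in step and return the order of the update used."""
    if solver_order == 1 or lower_order_nums < 1 or lower_order_final:
        return 1
    elif solver_order == 2 or lower_order_nums < 2 or lower_order_second:
        return 2
    return 3


def orders_over_steps(solver_order, num_steps):
    """Simulate the lower_order_nums warm-up across consecutive step() calls."""
    lower_order_nums = 0
    orders = []
    for _ in range(num_steps):
        orders.append(select_update_order(solver_order, lower_order_nums))
        # Same bookkeeping as in step: count how many past outputs are available.
        if lower_order_nums < solver_order:
            lower_order_nums += 1
    return orders


print(orders_over_steps(2, 4))  # [1, 2, 2, 2]
print(orders_over_steps(3, 4))  # [1, 2, 3, 3]
```

The first call always uses the first-order update because a multistep method needs earlier model outputs before it can use a higher order.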
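Going one level deeper, the solver comes in four variants selected by algorithm_type: "dpmsolver++", "dpmsolver", "sde-dpmsolver++" and "sde-dpmsolver". Roughly speaking, "dpmsolver++" and "dpmsolver" both form a weighted combination of the latent matrix at timestep t and the model output, while the two "sde-" variants additionally add a random noise term whose weight depends on sigma_t and the step size h. Note that before this function is called, step has already converted the raw UNet output: for the two "++" variants model_output is the predicted clean sample, and for the other two it is the predicted noise, so only in the latter case does the update literally subtract the predicted noise.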
def dpm_solver_first_order_update(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    One step for the first-order DPMSolver (equivalent to DDIM).
    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    # part of the code omitted...
    sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
    h = lambda_t - lambda_s
    if self.config.algorithm_type == "dpmsolver++":
        x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
    elif self.config.algorithm_type == "dpmsolver":
        x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
    elif self.config.algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        x_t = (
            (sigma_t / sigma_s * torch.exp(-h)) * sample
            + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )
    elif self.config.algorithm_type == "sde-dpmsolver":
        assert noise is not None
        x_t = (
            (alpha_t / alpha_s) * sample
            - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
            + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
        )
    return x_t
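The four branches can be checked numerically with a scalar re-implementation. The sketch below copies the formulas from the listing but uses plain floats; the function name `first_order_update`, the chosen noise levels and the values of `x0` and `eps` are assumptions made for illustration. It feeds the "dpmsolver++" branch a clean-sample prediction and the "dpmsolver" branch a noise prediction, which should be two routes to the same result when both predictions are exact.

```python
import math


def first_order_update(algorithm_type, sample, model_output,
                       alpha_s, sigma_s, alpha_t, sigma_t, noise=0.0):
    """Scalar copy of the four first-order update formulas (s = current, t = next)."""
    lambda_t = math.log(alpha_t) - math.log(sigma_t)
    lambda_s = math.log(alpha_s) - math.log(sigma_s)
    h = lambda_t - lambda_s  # step size in log-SNR, positive when denoising
    if algorithm_type == "dpmsolver++":
        return (sigma_t / sigma_s) * sample - alpha_t * (math.exp(-h) - 1.0) * model_output
    if algorithm_type == "dpmsolver":
        return (alpha_t / alpha_s) * sample - sigma_t * (math.exp(h) - 1.0) * model_output
    if algorithm_type == "sde-dpmsolver++":
        return ((sigma_t / sigma_s * math.exp(-h)) * sample
                + alpha_t * (1.0 - math.exp(-2.0 * h)) * model_output
                + sigma_t * math.sqrt(1.0 - math.exp(-2.0 * h)) * noise)
    if algorithm_type == "sde-dpmsolver":
        return ((alpha_t / alpha_s) * sample
                - 2.0 * sigma_t * (math.exp(h) - 1.0) * model_output
                + sigma_t * math.sqrt(math.exp(2.0 * h) - 1.0) * noise)
    raise ValueError(f"unknown algorithm_type: {algorithm_type}")


# Hypothetical setup: a clean value x0, a noise draw eps, and two noise levels
# satisfying alpha^2 + sigma^2 == 1, with level s noisier than level t.
x0, eps = 2.0, 0.7
levels = dict(alpha_s=math.cos(1.2), sigma_s=math.sin(1.2),
              alpha_t=math.cos(0.9), sigma_t=math.sin(0.9))
sample = levels["alpha_s"] * x0 + levels["sigma_s"] * eps
target = levels["alpha_t"] * x0 + levels["sigma_t"] * eps  # same eps, less noise

from_x0 = first_order_update("dpmsolver++", sample, x0, **levels)
from_eps = first_order_update("dpmsolver", sample, eps, **levels)
print(from_x0, from_eps, target)  # the three values agree up to rounding

# The SDE variant differs from its noise-free evaluation only by the noise term.
with_noise = first_order_update("sde-dpmsolver++", sample, x0, noise=1.0, **levels)
without_noise = first_order_update("sde-dpmsolver++", sample, x0, noise=0.0, **levels)
print(with_noise - without_noise)  # weight given to the fresh random noise
```

With exact predictions both deterministic variants land on the same less-noisy latent, which is consistent with the docstring's remark that the first-order update is equivalent to DDIM. The "sde-" variants keep less of the existing noise and replace the rest with fresh noise.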
Further reading: