For details, see the paper:
arXiv:2010.02502v4 (Denoising Diffusion Implicit Models)
Intuition
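DDPM is a Markov process: the state at each timestep depends only on the state at the previous timestep, so sampling has to walk through all T steps. DDIM reformulates the model so that the forward noising process still holds while the denoising process no longer has to be Markovian. Denoising can therefore run on a subsequence of L timesteps drawn from the full sequence (for example T = 1000, with L much smaller than T), which speeds up inference.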
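To make the idea concrete, here is a minimal scalar sketch of sampling on a subsequence. It is not the code from the paper or from any library. It assumes a linear beta schedule, deterministic sampling (eta = 0), and a hypothetical "oracle" noise predictor that knows the true x0 in place of a trained network. The helper names make_alphas_cumprod, subsample_timesteps, ddim_step and ddim_sample_scalar are invented for illustration.

```python
import math

def make_alphas_cumprod(T=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) under an assumed linear beta schedule."""
    alphas_cumprod, running = [], 1.0
    for t in range(T):
        beta = beta_start + (beta_end - beta_start) * t / (T - 1)
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return alphas_cumprod

def subsample_timesteps(T, L):
    """L + 1 evenly spaced integer timesteps from T-1 down to -1 (-1 means fully denoised)."""
    times = [int(-1 + i * T / L) for i in range(L + 1)]
    return list(reversed(times))

def ddim_step(x_t, eps_pred, a_t, a_next):
    """Deterministic (eta = 0) DDIM update from noise level a_t to noise level a_next."""
    # Estimate the clean sample from the current sample and the predicted noise.
    x0_pred = (x_t - math.sqrt(1.0 - a_t) * eps_pred) / math.sqrt(a_t)
    # Re-noise that estimate to the target level, reusing the same predicted noise.
    return math.sqrt(a_next) * x0_pred + math.sqrt(1.0 - a_next) * eps_pred

def ddim_sample_scalar(x_T, eps_model, alphas_cumprod, L):
    """Run L DDIM steps instead of len(alphas_cumprod) steps."""
    times = subsample_timesteps(len(alphas_cumprod), L)
    x = x_T
    for t, t_next in zip(times[:-1], times[1:]):
        a_t = alphas_cumprod[t]
        # t_next == -1 marks the final step, where no noise is left.
        a_next = alphas_cumprod[t_next] if t_next >= 0 else 1.0
        x = ddim_step(x, eps_model(x, t), a_t, a_next)
    return x

alphas_cumprod = make_alphas_cumprod()
x0, eps = 0.7, -1.3  # arbitrary clean value and noise draw
a_T = alphas_cumprod[-1]
x_T = math.sqrt(a_T) * x0 + math.sqrt(1.0 - a_T) * eps  # forward noising in closed form

def oracle(x, t):
    """Hypothetical perfect noise predictor; a real sampler would call a trained model."""
    a = alphas_cumprod[t]
    return (x - math.sqrt(a) * x0) / math.sqrt(1.0 - a)

x_rec = ddim_sample_scalar(x_T, oracle, alphas_cumprod, L=50)
print(subsample_timesteps(1000, 50)[:3], abs(x_rec - x0))
```

With a perfect noise predictor, 50 steps recover x0 up to floating-point error. This works because each update only relies on the closed-form relation between x_t and x0, not on visiting every intermediate timestep. A trained network is only approximately correct, so in practice fewer steps trade some sample quality for speed.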
Implementation
import torch

@torch.no_grad()
def ddim_sample(self, shape, cond, **kwargs):
    batch, device = shape[0], self.betas.device
    # eta scales the sampling noise: 0 gives deterministic DDIM, 1 matches the DDPM noise level
    total_timesteps, sampling_timesteps, eta = self.n_timestep, 50, 1
    # Evenly spaced timesteps; [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
    times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
    times = list(reversed(times.int().tolist()))
    # (current, next) pairs; in that same case [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
    time_pairs = list(zip(times[:-1], times[1:]))