DDIM
论文地址:Denoising Diffusion Implicit Models
代码地址:https://github.com/ermongroup/ddim
一、概述
DDIM 是基于 DDPM 改进的迭代隐式概率扩撒模型,核心目标是在保持生成质量的同时加速采样过程。通过引入非马尔可夫扩散过程和确定性采样机制,DDIM 允许在去噪时跳过部分时间步,可以显著减少计算量。其核心创新在于:
- 可调方差参数:通过控制反向过程的随机性,实现从完全随机(DDPM)到完全确定(无噪声)的采样模式;
- 跳跃式采样:无需遍历所有时间步,可直接在预设的关键时间点之间跳转,大幅提升生成速度。
DDIM 的主要特点包括:
- 非马尔可夫过程:打破 DDPM 的严格马尔可夫链限制,允许当前状态依赖任意历史状态;
- 确定性采样:通过设置方差为 0,消除采样过程的随机性,提升生成稳定性;
- 采样效率:支持“跳步”采样,在 10-50 步内即可生成高质量样本(DDPM 需 1000 步)。
二、主要步骤
1. 正向扩散过程
DDIM 的正向扩散过程与 DDPM 一致,都是为了在每个时间步 中,逐渐增加噪声的比例,将原始数据
转变为带有噪声的数据
,其中
是扩散系数,
是标准正态分布的噪声。
-
单步扩散:
-
边际分布:
其中 \beta_t是随时间递增的方差调度(如线性调度
。
表示均值为
,协方差矩阵为
的多元高斯分布。
2. 反向去噪过程
DDIM 的反向去噪过程与 DDPM 不同,它在不利用马尔可夫假设的情况下推导出了 diffusion 的反向过程,最终可以实现仅采样 20 ~ 100 步的情况下达到和 DDPM 采样 1000 步相近的生成效果。
-
条件分布重构:
DDIM 重新定义反向过程为带可调方差的高斯分布:其中均值
由神经网络预测的噪声
推导:
这里
是对原始数据的估计。
-
确定性采样(σₜ=0):
当方差时,采样过程完全确定,无需添加随机噪声:
与 DDPM 相比,DDIM 通过设定方差参数
,生成过程完全确定,消除随机性干扰。
3. 跳跃式采样
跳跃式采样是 DDIM 所采取的最关键的核心优化,它允许 DDIM 在采样时跳过一些中间时间步,加快采样速度。形式化地来说,DDPM 的采样时间步应当是 ,而 DDIM 可以直接从其中抽取一个子序列
进行采样,此时只需递归应用公式:
通过预设时间步子集(如每隔 10 步采样一次),可在大幅减少计算量的同时保持生成质量。
三、数学理论
1. 反向过程的条件分布
DDIM 通过引入可调方差参数 ,将反向过程的条件分布扩展为:
当 时,退化为 DDPM 的马尔可夫采样;当
时,采样过程完全确定。
2. 变分下界简化
DDIM 的训练目标与 DDPM 一致,都是最小化噪声预测损失,但 DIMM 通过非马尔可夫设计简化了变分下界:
其中权重 由
决定,且省略了 DDPM 中与边界条件相关的项(如
,大幅降低计算复杂度。
则是两个高斯分布的 KL 散度。
四、模型结构
DDIM 沿用 DDPM 的 U-Net 架构作为主干网络,包含对称的编码器-解码器路径和跳跃连接,但针对采样效率进行了轻量化调整:
1. 网络设计细节
- 归一化与激活:使用GroupNorm替代 BatchNorm 用以提升小批量训练稳定性,使用SiLU 激活函数替代 ReLU,增强非线性建模能力;
- 时间嵌入:将时间步 t编码为高维向量(如正弦编码或可学习嵌入),通过线性层与各层特征融合;
- 跳跃连接:保留原来的编码器-解码器的多尺度特征融合,确保细节恢复能力。
2. 关键模块对比
- 采样层:DDIM 的
p_sample
方法通过判断是否为 0,决定是否添加随机噪声,默认σ_t=0 时为纯确定性计算;
- 时间步处理:支持任意时间步跳转,无需按顺序遍历,通过预设的时间步列表(如
)实现跳步采样。
五、代码实现
# 核心采样逻辑
class DDIM(nn.Cell):
"""DDIM核心类,实现跳跃式确定性采样"""
def __init__(self, model, betas, T=1000, sample_steps=50):
super().__init__()
self.model = model # U-Net网络
self.T = T # 总时间步
self.sample_steps = sample_steps # 采样时使用的跳步步长
self.betas = betas
self.alphas = 1. - betas
self.alpha_bars = np.cumprod(self.alphas)
# 生成跳步时间序列(如从T到0,每隔T/sample_steps步取一个点)
self.sampling_timesteps = np.linspace(0, T-1, sample_steps, dtype=np.int64)[::-1]
def p_sample(self, x, t):
"""确定性去噪单步(σ=0)"""
alpha = self.alphas[t]
alpha_bar = self.alpha_bars[t]
sqrt_alpha = ops.sqrt(alpha)
sqrt_one_minus_alpha = ops.sqrt(1 - alpha)
# 预测噪声并估计原始数据
pred_noise = self.model(x, t)
pred_x0 = (x - sqrt_one_minus_alpha * pred_noise) / sqrt_alpha
# DDIM确定性采样公式
alpha_bar_prev = self.alpha_bars[t-1] if t > 0 else 1.0
sqrt_alpha_bar_prev = ops.sqrt(alpha_bar_prev)
sqrt_one_minus_alpha_bar_prev = ops.sqrt(1 - alpha_bar_prev)
x_prev = sqrt_alpha_bar_prev * pred_x0 + sqrt_one_minus_alpha_bar_prev * pred_noise
return x_prev
def construct(self, x):
"""跳步采样过程(从x_T到x_0)"""
for t in self.sampling_timesteps:
x = self.p_sample(x, t)
return x
# U-Net 改进
class UNet(nn.Cell):
"""带GroupNorm和SiLU的轻量化U-Net"""
def __init__(self, in_channels=3, channel_dim=128):
super().__init__()
self.time_embed = nn.SequentialCell(
nn.Embedding(1000, channel_dim),
nn.SiLU(),
nn.Dense(channel_dim, channel_dim * 4)
)
self.down = nn.SequentialCell(
nn.Conv2d(in_channels, channel_dim, 3, padding=1),
nn.GroupNorm(32, channel_dim),
nn.SiLU(),
nn.Conv2d(channel_dim, channel_dim * 2, 3, padding=1, stride=2),
nn.GroupNorm(32, channel_dim * 2),
nn.SiLU()
)
self.up = nn.SequentialCell(
nn.Conv2dTranspose(channel_dim * 2, channel_dim, 3, stride=2, padding=1),
nn.GroupNorm(32, channel_dim),
nn.SiLU(),
nn.Conv2d(channel_dim, in_channels, 3, padding=1),
nn.Tanh()
)
def construct(self, x, t):
t_emb = self.time_embed(t)
h = self.down(x) + t_emb.view(-1, h.shape[1], 1, 1)
return self.up(h)