MindSpore实现扩散模型系列——DDIM

DDIM

论文地址:Denoising Diffusion Implicit Models

代码地址:https://github.com/ermongroup/ddim

一、概述

DDIM 是基于 DDPM 改进的迭代隐式概率扩撒模型,核心目标是在保持生成质量的同时加速采样过程。通过引入非马尔可夫扩散过程和确定性采样机制,DDIM 允许在去噪时跳过部分时间步,可以显著减少计算量。其核心创新在于:

  1. 可调方差参数:通过控制反向过程的随机性,实现从完全随机(DDPM)到完全确定(无噪声)的采样模式;
  2. 跳跃式采样:无需遍历所有时间步,可直接在预设的关键时间点之间跳转,大幅提升生成速度。

DDIM 的主要特点包括:

  • 非马尔可夫过程:打破 DDPM 的严格马尔可夫链限制,允许当前状态依赖任意历史状态;
  • 确定性采样:通过设置方差为 0,消除采样过程的随机性,提升生成稳定性;
  • 采样效率:支持“跳步”采样,在 10-50 步内即可生成高质量样本(DDPM 需 1000 步)。

二、主要步骤

1. 正向扩散过程

DDIM 的正向扩散过程与 DDPM 一致,都是为了在每个时间步 t 中,逐渐增加噪声的比例,将原始数据 x_0转变为带有噪声的数据 x_t = \alpha_t x_0 + (1 - \alpha_t) \epsilon,其中 \alpha_t是扩散系数,\epsilon是标准正态分布的噪声。

  • 单步扩散:

    q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t \mathbf{I}\right)

  • 边际分布:

    x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, \quad \bar{\alpha}_t = \prod_{s=1}^t \alpha_s, \, \alpha_s = 1 - \beta_s

    其中 \beta_t是随时间递增的方差调度(如线性调度 \beta_t = \beta_{\text{start}} + t*(\beta_{\text{end}} - \beta_{\text{start}})/T\mathcal{N}(\cdot; \mu, \sigma^2 I)表示均值为 \mu,协方差矩阵为 \sigma^2 I的多元高斯分布。

2. 反向去噪过程

DDIM 的反向去噪过程与 DDPM 不同,它在不利用马尔可夫假设的情况下推导出了 diffusion 的反向过程,最终可以实现仅采样 20 ~ 100 步的情况下达到和 DDPM 采样 1000 步相近的生成效果。

  • 条件分布重构:
    DDIM 重新定义反向过程为带可调方差的高斯分布:

    p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\left(x_{t-1}; \tilde{\mu}_t(x_t, \epsilon_\theta), \sigma_t^2 \mathbf{I}\right)

    其中均值 \tilde{\mu}_t由神经网络预测的噪声 \epsilon_\theta(x_t, t)推导:

    \tilde{\mu}_t = \sqrt{\alpha_{t-1}} \cdot \hat{x}_0 + \sqrt{1 - \alpha_{t-1} - \sigma_t^2} \cdot \frac{x_t - \sqrt{\alpha_t} \hat{x}_0}{\sqrt{1 - \alpha_t}}

    这里 \hat{x}_0 = \frac{x_t - \sqrt{1 - \alpha_t} \epsilon_\theta}{\sqrt{\alpha_t}}是对原始数据的估计。

  • 确定性采样(σₜ=0):
    当方差 \sigma_t = 0时,采样过程完全确定,无需添加随机噪声:

    x_{t-1} = \sqrt{\alpha_{t-1}} \cdot \hat{x}_0 + \sqrt{1 - \alpha_{t-1}} \cdot \epsilon_\theta(x_t, t)

    与 DDPM 相比,DDIM 通过设定方差参数 \sigma_t=0,生成过程完全确定,消除随机性干扰。

3. 跳跃式采样

跳跃式采样是 DDIM 所采取的最关键的核心优化,它允许 DDIM 在采样时跳过一些中间时间步,加快采样速度。形式化地来说,DDPM 的采样时间步应当是 [T, T-1, ..., 2, 1],而 DDIM 可以直接从其中抽取一个子序列 [T_s, T_{s-1}, ..., T_2, T_1]进行采样,此时只需递归应用公式:

x_s = \sqrt{\bar{\alpha}_s} \cdot \hat{x}_0 + \sqrt{1 - \bar{\alpha}_s} \cdot \epsilon_\theta(x_t, t)

通过预设时间步子集(如每隔 10 步采样一次),可在大幅减少计算量的同时保持生成质量。

三、数学理论

1. 反向过程的条件分布

DDIM 通过引入可调方差参数 \sigma_t,将反向过程的条件分布扩展为:

q_\sigma(x_{t-1} \mid x_t, x_0) = \mathcal{N}\left(x_{t-1}; \frac{\sqrt{\alpha_{t-1}} (x_t - \sqrt{1 - \alpha_t} \epsilon)}{\sqrt{\alpha_t}} , \sigma_t^2 \mathbf{I}\right)

\sigma_t = \sqrt{\beta_t}时,退化为 DDPM 的马尔可夫采样;当 \sigma_t = 0时,采样过程完全确定。

2. 变分下界简化

DDIM 的训练目标与 DDPM 一致,都是最小化噪声预测损失,但 DIMM 通过非马尔可夫设计简化了变分下界:

L_{\text{DDIM}} = \mathbb{E}_{q} \left[ \sum_{t=2}^{T} \omega_t \cdot D_{\text{KL}}\left(q(x_{t-1} \mid x_t, x_0) \parallel p_\theta(x_{t-1} \mid x_t)\right) \right]

其中权重 \omega*t\sigma_t决定,且省略了 DDPM 中与边界条件相关的项(如 D_{\text{KL}}(q(x_T \mid x_0) \parallel p(x_T)),大幅降低计算复杂度。D_{\text{KL}}则是两个高斯分布的 KL 散度。

四、模型结构

DDIM 沿用 DDPM 的 U-Net 架构作为主干网络,包含对称的编码器-解码器路径和跳跃连接,但针对采样效率进行了轻量化调整:

1. 网络设计细节
  • 归一化与激活:使用GroupNorm替代 BatchNorm 用以提升小批量训练稳定性,使用SiLU 激活函数替代 ReLU,增强非线性建模能力;
  • 时间嵌入:将时间步 t编码为高维向量(如正弦编码或可学习嵌入),通过线性层与各层特征融合;
  • 跳跃连接:保留原来的编码器-解码器的多尺度特征融合,确保细节恢复能力。
2. 关键模块对比
  • 采样层:DDIM 的p_sample方法通过判断σ_t是否为 0,决定是否添加随机噪声,默认σ_t=0 时为纯确定性计算;
  • 时间步处理:支持任意时间步跳转,无需按顺序遍历,通过预设的时间步列表(如[T_s, T_{s-1}, ..., T_2, T_1])实现跳步采样。

五、代码实现

# 核心采样逻辑
class DDIM(nn.Cell):
    """DDIM核心类,实现跳跃式确定性采样"""
    def __init__(self, model, betas, T=1000, sample_steps=50):
        super().__init__()
        self.model = model  # U-Net网络
        self.T = T          # 总时间步
        self.sample_steps = sample_steps  # 采样时使用的跳步步长
        self.betas = betas
        self.alphas = 1. - betas
        self.alpha_bars = np.cumprod(self.alphas)

        # 生成跳步时间序列(如从T到0,每隔T/sample_steps步取一个点)
        self.sampling_timesteps = np.linspace(0, T-1, sample_steps, dtype=np.int64)[::-1]

    def p_sample(self, x, t):
        """确定性去噪单步(σ=0)"""
        alpha = self.alphas[t]
        alpha_bar = self.alpha_bars[t]
        sqrt_alpha = ops.sqrt(alpha)
        sqrt_one_minus_alpha = ops.sqrt(1 - alpha)

        # 预测噪声并估计原始数据
        pred_noise = self.model(x, t)
        pred_x0 = (x - sqrt_one_minus_alpha * pred_noise) / sqrt_alpha

        # DDIM确定性采样公式
        alpha_bar_prev = self.alpha_bars[t-1] if t > 0 else 1.0
        sqrt_alpha_bar_prev = ops.sqrt(alpha_bar_prev)
        sqrt_one_minus_alpha_bar_prev = ops.sqrt(1 - alpha_bar_prev)

        x_prev = sqrt_alpha_bar_prev * pred_x0 + sqrt_one_minus_alpha_bar_prev * pred_noise
        return x_prev

    def construct(self, x):
        """跳步采样过程(从x_T到x_0)"""
        for t in self.sampling_timesteps:
            x = self.p_sample(x, t)
        return x

# U-Net 改进
class UNet(nn.Cell):
    """带GroupNorm和SiLU的轻量化U-Net"""
    def __init__(self, in_channels=3, channel_dim=128):
        super().__init__()
        self.time_embed = nn.SequentialCell(
            nn.Embedding(1000, channel_dim),
            nn.SiLU(),
            nn.Dense(channel_dim, channel_dim * 4)
        )

        self.down = nn.SequentialCell(
            nn.Conv2d(in_channels, channel_dim, 3, padding=1),
            nn.GroupNorm(32, channel_dim),
            nn.SiLU(),
            nn.Conv2d(channel_dim, channel_dim * 2, 3, padding=1, stride=2),
            nn.GroupNorm(32, channel_dim * 2),
            nn.SiLU()
        )

        self.up = nn.SequentialCell(
            nn.Conv2dTranspose(channel_dim * 2, channel_dim, 3, stride=2, padding=1),
            nn.GroupNorm(32, channel_dim),
            nn.SiLU(),
            nn.Conv2d(channel_dim, in_channels, 3, padding=1),
            nn.Tanh()
        )

    def construct(self, x, t):
        t_emb = self.time_embed(t)
        h = self.down(x) + t_emb.view(-1, h.shape[1], 1, 1)
        return self.up(h)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值