【笔记】扩散模型(二):DDIM 理论推导与代码实现

论文链接:Denoising Diffusion Implicit Models

在正式开始之前先贴一个生成结果示例:
语义插值效应

上一篇文章中我们进行了 DDPM 的理论推导,并且自己编写代码实现了 DDPM 的训练和采样过程。虽然取得了还不错的效果,但 DDPM 有一个非常明显的问题:采样过程很慢。因为 DDPM 的反向过程利用了马尔可夫假设,所以每次都必须在相邻的时间步之间进行去噪,而不能跳过中间步骤。原始论文使用了 1000 个时间步,所以我们在采样时也需要循环 1000 次去噪过程,这个过程是非常慢的。

为了加速 DDPM 的采样过程,DDIM 在不利用马尔可夫假设的情况下推导出了 diffusion 的反向过程,最终可以实现仅采样 20~100 步的情况下达到和 DDPM 采样 1000 步相近的生成效果,也就是提速 10~50 倍。这篇文章将对 DDIM 的理论进行讲解,并实现 DDIM 采样的代码。

DDPM 的反向过程

首先我们回顾一下 DDPM 反向过程的推导,为了推导出 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1}|\mathbf{x}_t) q(xt1xt) 这个条件概率分布,DDPM 利用贝叶斯公式将其变成了先验分布的组合,并且通过向条件中加入 x 0 \mathbf{x}_0 x0 将所有的分布转换为已知分布:
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)=\frac{q(\mathbf{x}_t|\mathbf{x}_{t-1},\mathbf{x}_0)q(\mathbf{x}_{t-1}|\mathbf{x}_0)}{q(\mathbf{x}_t|\mathbf{x}_0)} q(xt1xt,x0)=q(xtx0)q(xtxt1,x0)q(xt1x0)
在上边这个等式的右侧, q ( x t − 1 ∣ x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_0) q(xt1x0) q ( x t ∣ x 0 ) q(\mathbf{x}_t|\mathbf{x}_0) q(xtx0) 都是已知的,需要求解的只有 q ( x t ∣ x t − 1 , x 0 ) q(\mathbf{x}_t|\mathbf{x}_{t-1},\mathbf{x}_0) q(xtxt1,x0)。在这里 DDPM 引入马尔可夫假设,认为 x t \mathbf{x}_t xt 只与 x t − 1 \mathbf{x}_{t-1} xt1 有关,将其转化成了 q ( x t ∣ x t − 1 ) q(\mathbf{x}_t|\mathbf{x}_{t-1}) q(xtxt1)。最后经过推导,得出条件概率分布:
q ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , σ t 2 I ) q(\mathbf{x}_{t-1}|\mathbf{x}_t)=\mathcal{N}(\mathbf{x}_{t-1};\mu_\theta(\mathbf{x}_t,t),\sigma_t^2\mathbf{I}) q(xt1xt)=N(xt1;μθ(xt,t),σt2I)
我们可以看到之所以 DDPM 很慢,就是因为在推导 q ( x t ∣ x t − 1 , x 0 ) q(\mathbf{x}_t|\mathbf{x}_{t-1},\mathbf{x}_0) q(xtxt1,x0) 的时候引入了马尔可夫假设,使得去噪只能在相邻时间步之间进行。如果我们可以在不依赖马尔可夫假设的情况下推导出 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0) q(xt1xt,x0),就可以将上面式子里的 t − 1 t-1 t1 替换为任意的中间时间步 τ \tau τ,从而实现采样加速。总结来说,DDIM 主要有两个出发点:

  1. 保持前向过程的分布 q ( x t ∣ x t − 1 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(\mathbf{x}_t|\mathbf{x}_{t-1})=\mathcal{N}\left(\mathbf{x}_t;\sqrt{\bar{\alpha}_t}\mathbf{x}_0,(1-\bar{\alpha}_t)\mathbf{I}\right) q(xtxt1)=N(xt;αˉt x0,(1αˉt)I) 不变;
  2. 构建一个不依赖于马尔可夫假设的 q ( x τ ∣ x t , x 0 ) q(\mathbf{x}_\tau|\mathbf{x}_t,\mathbf{x}_0) q(xτxt,x0) 分布。

q ( x τ ∣ x t , x 0 ) q(\mathbf{x}_\tau|\mathbf{x}_t,\mathbf{x}_0) q(xτxt,x0) 的推导

开始推导之前简单说明一下,这个 q ( x τ ∣ x t , x 0 ) q(\mathbf{x}_\tau|\mathbf{x}_t,\mathbf{x}_0) q(xτxt,x0) 实际上就是上一章中提到的 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0) q(xt1xt,x0),只不过是因为我们的推导不再依赖马尔可夫假设,所以 t − 1 t-1 t1 可以替换为任意的 τ ∈ ( 0 , t ) \tau\in(0,t) τ(0,t)。为了避免混淆,我们在这里使用一个通用的符号 τ ∈ ( 0 , t ) \tau\in(0,t) τ(0,t) 表示中间的时间步。

另一点需要说明的是,在 DDIM 的论文中, α \alpha α 表示的含义和 DDPM 论文中的 α ˉ \bar{\alpha} αˉ 相同。为了保证前后一致,我们在这里依然使用 DDPM 的符号约定,令 α t = 1 − β t \alpha_t=1-\beta_t αt=1βt α ˉ t = ∏ i = 1 t α i \bar{\alpha}_t=\prod_{i=1}^t\alpha_i αˉt=i=1tαi

我们在 DDPM 里已经推导出了 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0) q(xt1xt,x0) 是一个高斯分布,均值和方差为:
μ = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 σ = ( α t β t + 1 1 − α ˉ t − 1 ) − 1 / 2 \begin{aligned} \mu&=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\mathbf{x}_t+\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}\mathbf{x}_0\\ \sigma&=\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar\alpha_{t-1}}\right)^{-1/2} \end{aligned} μσ=1αˉtαt (1αˉt1)xt+1αˉtαˉt1 βtx0=(βtαt+1αˉt11)1/2
可以看到均值是 x 0 \mathbf{x}_0 x0 x t \mathbf{x}_t xt 的线性组合,方差是时间步的函数。DDIM 基于这样的规律,使用待定系数法:
q ( x τ ∣ x t , x 0 ) = N ( x τ ; λ x 0 + k x t , σ t 2 I ) q(\mathbf{x}_\tau|\mathbf{x}_t,\mathbf{x}_0)=\mathcal{N}(\mathbf{x}_\tau;\lambda\mathbf{x}_0+k\mathbf{x}_t,\sigma_t^2\mathbf{I}) q(xτxt,x0)=N(xτ;λx0+kxt,σt2I)
也就是 x τ = λ x 0 + k x t + σ t ϵ τ \mathbf{x}_\tau=\lambda\mathbf{x}_0+k\mathbf{x}_t+\sigma_t\epsilon_\tau xτ=λx0+kxt+σtϵτ。又因为前向过程满足 x t = α ˉ t x 0 + 1 − α ˉ t ϵ t \mathbf{x}_t=\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\epsilon_t xt=αˉt x0+1αˉt ϵt,代入可以得到:
x τ = λ x 0 + k x t + σ t ϵ τ = λ x 0 + k ( α ˉ t x 0 + 1 − α ˉ t ϵ t ) + σ t ϵ τ = ( λ + k α ˉ t ) x 0 + ( k 1 − α ˉ t ϵ t + σ t ϵ τ ) = ( λ + k α ˉ t ) x 0 + k 2 ( 1 − α ˉ t ) + σ t 2 ϵ \begin{aligned} \mathbf{x}_\tau&=\lambda\mathbf{x}_0+k\mathbf{x}_t+\sigma_t\epsilon_\tau\\ &=\lambda\mathbf{x}_0+k(\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\epsilon_t)+\sigma_t\epsilon_\tau\\ &=(\lambda+k\sqrt{\bar{\alpha}_t})\mathbf{x}_0+(k\sqrt{1-\bar{\alpha}_t}\epsilon_t+\sigma_t\epsilon_\tau)\\ &=(\lambda+k\sqrt{\bar{\alpha}_t})\mathbf{x}_0+\sqrt{k^2(1-\bar{\alpha}_t)+\sigma_t^2}\epsilon \end{aligned} xτ=λx0+kxt+σtϵτ=λx0+k(αˉt x0+1αˉt ϵt)+σtϵτ=(λ+kαˉt )x0+(k1αˉt ϵt+σtϵτ)=(λ+kαˉt )x0+k2(1αˉt)+σt2 ϵ
在上面的推导过程中,由于 ϵ t \epsilon_t ϵt ϵ τ \epsilon_\tau ϵτ 都满足标准正态分布,因此两项可以合并。又因为根据前向过程,有 x τ = α ˉ τ x 0 + 1 − α ˉ τ ϵ τ \mathbf{x}_\tau=\sqrt{\bar{\alpha}_\tau}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_\tau}\epsilon_\tau xτ=αˉτ x0+1αˉτ ϵτ,将两个式子的系数对比,可以得到方程组:
{ λ + k α ˉ t = α ˉ τ k 2 ( 1 − α ˉ t ) + σ t 2 = 1 − α ˉ τ \begin{cases} \begin{aligned} \lambda+k\sqrt{\bar{\alpha}_t}&=\sqrt{\bar{\alpha}_\tau}\\ \sqrt{k^2(1-\bar{\alpha}_t)+\sigma_t^2}&=\sqrt{1-\bar{\alpha}_\tau} \end{aligned} \end{cases} λ+kαˉt k2(1αˉt)+σt2 =αˉτ =1αˉτ
解方程组得到 λ \lambda λ k k k
{ λ = α ˉ τ − ( 1 − α ˉ τ − σ t 2 ) α ˉ t 1 − α ˉ t k = 1 − α ˉ τ − σ t 2 1 − α ˉ t \begin{cases} \begin{aligned} \lambda&=\sqrt{\bar{\alpha}_\tau}-\sqrt{\frac{(1-\bar{\alpha}_\tau-\sigma_t^2)\bar{\alpha}_t}{1-\bar{\alpha}_t}}\\ k&=\sqrt{\frac{1-\bar{\alpha}_\tau-\sigma_t^2}{1-\bar{\alpha}_t}} \end{aligned} \end{cases} λk=αˉτ 1αˉt(1αˉτσt2)αˉt =1αˉt1αˉτσt2
在上边的结果中,我们得到了 q ( x τ ∣ x t , x 0 ) q(\mathbf{x}_\tau|\mathbf{x}_t,\mathbf{x}_0) q(xτxt,x0) 均值中的两个参数,而方差 σ t 2 \sigma_t^2 σt2 并没有唯一定值,因此这个结果对应于一组解,通过规定不同的方差,可以得到不同的采样过程。我们把 x 0 \mathbf{x}_0 x0 x t \mathbf{x}_t xt 替换,可以得到均值的表达式:
μ = λ x 0 + k x t = ( α ˉ τ − ( 1 − α ˉ τ − σ t 2 ) α ˉ t 1 − α ˉ t ) x 0 + 1 − α ˉ τ − σ t 2 1 − α ˉ t x t = ( α ˉ τ − ( 1 − α ˉ τ − σ t 2 ) α ˉ t 1 − α ˉ t ) ( x t − 1 − α ˉ t ϵ θ ( x t , t ) α ˉ t ) + 1 − α ˉ τ − σ t 2 1 − α ˉ t x t = α ˉ τ x t − 1 − α ˉ t ϵ θ ( x t , t ) α ˉ t + 1 − α ˉ τ − σ t 2 ϵ θ ( x t , t ) \begin{aligned} \mu&=\lambda\mathbf{x}_0+k\mathbf{x}_t\\ &=\left(\sqrt{\bar{\alpha}_\tau}-\sqrt{\frac{(1-\bar{\alpha}_\tau-\sigma_t^2)\bar{\alpha}_t}{1-\bar{\alpha}_t}}\right)\mathbf{x}_0+\sqrt{\frac{1-\bar{\alpha}_\tau-\sigma_t^2}{1-\bar{\alpha}_t}}\mathbf{x}_t\\ &=\left(\sqrt{\bar{\alpha}_\tau}-\sqrt{\frac{(1-\bar{\alpha}_\tau-\sigma_t^2)\bar{\alpha}_t}{1-\bar{\alpha}_t}}\right)\left(\frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t}\epsilon_\theta(\mathbf{x}_t,t)}{\sqrt{\bar{\alpha}_t}}\right)+\sqrt{\frac{1-\bar{\alpha}_\tau-\sigma_t^2}{1-\bar{\alpha}_t}}\mathbf{x}_t\\ &=\sqrt{\bar{\alpha}_\tau}\frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t}\epsilon_\theta(\mathbf{x}_t,t)}{\sqrt{\bar{\alpha}_t}}+\sqrt{1-\bar{\alpha}_\tau-\sigma_t^2}\epsilon_\theta(\mathbf{x}_t,t) \end{aligned} μ=λx0+kxt= αˉτ 1αˉt(1αˉτσt2)αˉt x0+1αˉt1αˉτσt2 xt= αˉτ 1αˉt(1αˉτσt2)αˉt (αˉt xt1αˉt ϵθ(xt,t))+1αˉt1αˉτσt2 xt=αˉτ αˉt xt1αˉt ϵθ(xt,t)+1αˉτσt2 ϵθ(xt,t)
因此我们可以得到最终的 x τ \mathbf{x}_\tau xτ 的表达式:
x τ = μ + σ t ϵ = α ˉ τ x t − 1 − α ˉ t ϵ θ ( x t , t ) α ˉ t ⏟ 预测的 x 0 + 1 − α ˉ τ − σ t 2 ϵ θ ( x t , t ) ⏟ 指向 x t 的方向 + σ t ϵ ⏟ 随机的噪声 \begin{aligned} \mathbf{x}_\tau&=\mu+\sigma_t\epsilon\\ &=\sqrt{\bar{\alpha}_\tau}\underbrace{\frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t}\epsilon_\theta(\mathbf{x}_t,t)}{\sqrt{\bar{\alpha}_t}}}_{预测的\mathbf{x}_0}+\underbrace{\sqrt{1-\bar{\alpha}_\tau-\sigma_t^2}\epsilon_\theta(\mathbf{x}_t,t)}_{指向\mathbf{x}_t的方向}+\underbrace{\sigma_t\epsilon}_{随机的噪声} \end{aligned} xτ=μ+σtϵ=αˉτ 预测的x0 αˉt xt1αˉt ϵθ(xt,t)+指向xt的方向 1αˉτσt2 ϵθ(xt,t)+随机的噪声 σtϵ

方差的取值

正如我们前文中所说,我们得到的实际上是 x τ \mathbf{x}_\tau xτ 的一组解,其中的 σ t \sigma_t σt 并没有固定的取值。在论文中,作者参照 DDPM 的方差的形式给出了一个 σ t \sigma_t σt 的形式:
σ t = η 1 − α ˉ τ 1 − α ˉ t 1 − α t \sigma_t=\eta\sqrt{\frac{1-\bar{\alpha}_\tau}{1-\bar{\alpha}_t}}\sqrt{1-\alpha_t} σt=η1αˉt1αˉτ 1αt

  1. η = 1 \eta=1 η=1,生成过程与 DDPM 一致。这个感觉还是可以理解的,因为在待定系数法求解时,本身就是假定均值的形式和 DDPM 相同,如果再假定方差和 DDPM 相同,那么最后的整体形式也会变成 DDPM。
  2. η = 0 \eta=0 η=0,此时生成过程不再添加随机噪声项,唯一带有随机性的因素就是采样初始的 x T ∼ N ( 0 , 1 ) \mathbf{x}_T\sim\mathcal{N}(0,1) xTN(0,1),因此采样的过程是确定的,每个 x T \mathbf{x}_T xT 对应唯一的 x 0 \mathbf{x}_0 x0,这个模型就是 DDIM。

采样加速

我们知道 DDIM 的反向过程并不依赖于马尔可夫假设,因此去噪的过程并不需要在相邻的时间步之间进行,也就是跳过一些中间的步骤。形式化地来说,DDPM 的采样时间步应当是 [ T , T − 1 , . . . , 2 , 1 ] [T,T-1,...,2,1] [T,T1,...,2,1],而 DDIM 可以直接从其中抽取一个子序列 [ τ S , τ S − 1 , . . . , τ 2 , τ 1 ] [\tau_S,\tau_{S-1},...,\tau_2,\tau_1] [τS,τS1,...,τ2,τ1] 进行采样。

在 DDIM 论文的附录中,给出了两种子序列的选取方式:

  • 线性选取:令 τ i = ⌊ c i ⌋ \tau_i=\lfloor ci\rfloor τi=ci
  • 二次方选取:令 τ i = ⌊ c i 2 ⌋ \tau_i=\lfloor ci^2\rfloor τi=ci2

其中 c c c 是一个常量,制定这个常量的规则是让 τ − 1 \tau_{-1} τ1 也就是最后一个采样时间步尽可能与 T T T 接近。在原文的实验中,CIFAR10 使用的是二次方选取,其他数据集都使用的是线性选取方式。

DDIM 区别于 DDPM 的两个特性

  1. 采样一致性:我们知道 DDIM 的采样过程是确定的,生成结果只受 x T \mathbf{x}_T xT 影响。作者经过实验发现对于同一个 x T \mathbf{x}_T xT,使用不同的采样过程,最终生成的 x 0 \mathbf{x}_0 x0 比较相近,因此 x T \mathbf{x}_T xT 在一定程度上可以看作 x 0 \mathbf{x}_0 x0 的一种嵌入。

    因为这个性质的存在,在生成图像时也有一个 trick。也就是一开始先选取一个较小的时间步数量生成比较粗糙的图像,如果大致样子符合预期,再使用大时间步数量进行精细生成。

  2. 语义插值效应:根据上一条性质, x T \mathbf{x}_T xT 可以看作 x 0 \mathbf{x}_0 x0 的嵌入,那么它可能也具有其他隐概率模型所具有的语义差值效应。作者首先选取两个隐变量 x T ( 0 ) \mathbf{x}_T^{(0)} xT(0) x T ( 1 ) \mathbf{x}_T^{(1)} xT(1),对其分别采样得到结果,然后使用球面线性插值得到一系列中间隐变量,这个插值定义为:
    x T ( α ) = sin ⁡ ( 1 − α ) θ sin ⁡ θ x T ( 0 ) + sin ⁡ α θ sin ⁡ θ x T ( 1 ) \mathbf{x}_T^{(\alpha)}=\frac{\sin(1-\alpha)\theta}{\sin\theta}\mathbf{x}_T^{(0)}+\frac{\sin\alpha\theta}{\sin\theta}\mathbf{x}_T^{(1)} xT(α)=sinθsin(1α)θxT(0)+sinθsinαθxT(1)
    其中 θ = arccos ⁡ ( ( x T ( 0 ) ) T x T ( 1 ) ∣ ∣ x T ( 0 ) ∣ ∣   ∣ ∣ x T ( 1 ) ∣ ∣ ) \theta=\arccos\left(\frac{(\mathbf{x}_T^{(0)})^T\mathbf{x}_T^{(1)}}{||\mathbf{x}_T^{(0)}||~||\mathbf{x}_T^{(1)}||}\right) θ=arccos(∣∣xT(0)∣∣ ∣∣xT(1)∣∣(xT(0))TxT(1))。最终也在 DDIM 上观察到了语义插值效应,我们下面也将复现这一实验。

DDIM 的代码实现

从上面的推导过程可以发现,DDIM 假设的前向过程和 DDPM 相同,只有采样过程不同。因此想把 DDPM 改成 DDIM 并不需要重新训练,只要修改采样过程就可以了。在上一篇文章中我们已经训练好了一个 DDPM 模型,这里我们继续用这个训练好的模型来构造 DDIM 的采样过程。

如果你没有看上一篇文章,也可以直接在这个链接直接下载训练好的权重。

我们把训练好的 DDPM 模型的权重加载进来用作噪声预测网络:

from diffusers import UNet2DModel

model = UNet2DModel.from_pretrained('ddpm-anime-faces-64').cuda()

核心代码

首先我们依然是定义一系列常量, α \alpha α β \beta β 等都和 DDPM 相同,只有采样的时间步不同。我们在这里直接线性选取 20 个时间步,最大的为 999,最小的为 0:

import torch

class DDIM:
    def __init__(
        self,
        num_train_timesteps:int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        sample_steps: int = 20,
    ):
        self.num_train_timesteps = num_train_timesteps
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.timesteps = torch.linspace(num_train_timesteps - 1, 0, sample_steps).long()

然后是实现采样过程,和 DDPM 一样,我们把需要的公式复制到这里,然后对照着实现:
x τ = α ˉ τ x t − 1 − α ˉ t ϵ θ ( x t , t ) α ˉ t + 1 − α ˉ τ − σ t 2 ϵ θ ( x t , t ) + σ t ϵ σ t = η 1 − α ˉ τ 1 − α ˉ t 1 − α t \begin{aligned} \mathbf{x}_\tau&=\sqrt{\bar{\alpha}_\tau}\frac{\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t}\epsilon_\theta(\mathbf{x}_t,t)}{\sqrt{\bar{\alpha}_t}}+\sqrt{1-\bar{\alpha}_\tau-\sigma_t^2}\epsilon_\theta(\mathbf{x}_t,t)+\sigma_t\epsilon\\ \sigma_t&=\eta\sqrt{\frac{1-\bar{\alpha}_\tau}{1-\bar{\alpha}_t}}\sqrt{1-\alpha_t} \end{aligned} xτσt=αˉτ αˉt xt1αˉt ϵθ(xt,t)+1αˉτσt2 ϵθ(xt,t)+σtϵ=η1αˉt1αˉτ 1αt

import math
from tqdm import tqdm

class DDIM:
	...

    @torch.no_grad()
    def sample(
        self,
        unet: UNet2DModel,
        batch_size: int,
        in_channels: int,
        sample_size: int,
        eta: float = 0.0,
    ):
        alphas = self.alphas.to(unet.device)
        alphas_cumprod = self.alphas_cumprod.to(unet.device)
        timesteps = self.timesteps.to(unet.device)
        images = torch.randn((batch_size, in_channels, sample_size, sample_size), device=unet.device)
        for t, tau in tqdm(list(zip(timesteps[:-1], timesteps[1:])), desc='Sampling'):
            pred_noise: torch.Tensor = unet(images, t).sample

            # sigma_t
            if not math.isclose(eta, 0.0):
                one_minus_alpha_prod_tau = 1.0 - alphas_cumprod[tau]
                one_minus_alpha_prod_t = 1.0 - alphas_cumprod[t]
                one_minus_alpha_t = 1.0 - alphas[t]
                sigma_t = eta * (one_minus_alpha_prod_tau * one_minus_alpha_t / one_minus_alpha_prod_t) ** 0.5
            else:
                sigma_t = torch.zeros_like(alphas[0])

            # first term of x_tau
            alphas_cumprod_tau = alphas_cumprod[tau]
            sqrt_alphas_cumprod_tau = alphas_cumprod_tau ** 0.5
            alphas_cumprod_t = alphas_cumprod[t]
            sqrt_alphas_cumprod_t = alphas_cumprod_t ** 0.5
            sqrt_one_minus_alphas_cumprod_t = (1.0 - alphas_cumprod_t) ** 0.5
            first_term = sqrt_alphas_cumprod_tau * (images - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t

            # second term of x_tau
            coeff = (1.0 - alphas_cumprod_tau - sigma_t ** 2) ** 0.5
            second_term = coeff * pred_noise

            epsilon = torch.randn_like(images)
            images = first_term + second_term + sigma_t * epsilon
        images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
        return images

上面的内容和 DDPM 大同小异,只有计算公式变了,应该没有太多坑,只要看清楚变量就可以了。最后我们执行采样过程:

ddim = DDIM()
images = ddim.sample(model, 32, 3, 64)

from diffusers.utils import make_image_grid, numpy_to_pil
image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=8)
image_grid.save('ddim-sample-results.png')

结果展示

采样速度的确是变快了很多,得到的结果如下图所示:

采样结果示例

感觉总体上采样效果比 DDPM 稍微有所下降,不过也还在可以接受的范围内,算是一种速度-质量的 tradeoff。

语义插值效应复现

语义插值效应也比较简单,只需要修改初始化的 x T \mathbf{x}_T xT​ 即可。根据上文的叙述,我们首先实现球面线性插值:
x T ( α ) = sin ⁡ ( 1 − α ) θ sin ⁡ θ x T ( 0 ) + sin ⁡ α θ sin ⁡ θ x T ( 1 ) ,    w h e r e   θ = arccos ⁡ ( ( x T ( 0 ) ) T x T ( 1 ) ∣ ∣ x T ( 0 ) ∣ ∣   ∣ ∣ x T ( 1 ) ∣ ∣ ) \mathbf{x}_T^{(\alpha)}=\frac{\sin(1-\alpha)\theta}{\sin\theta}\mathbf{x}_T^{(0)}+\frac{\sin\alpha\theta}{\sin\theta}\mathbf{x}_T^{(1)},~~\mathrm{where}~\theta=\arccos\left(\frac{(\mathbf{x}_T^{(0)})^T\mathbf{x}_T^{(1)}}{||\mathbf{x}_T^{(0)}||~||\mathbf{x}_T^{(1)}||}\right) xT(α)=sinθsin(1α)θxT(0)+sinθsinαθxT(1),  where θ=arccos(∣∣xT(0)∣∣ ∣∣xT(1)∣∣(xT(0))TxT(1))

import torch

def slerp(
    x0: torch.Tensor,
    x1: torch.Tensor,
    alpha: float,
):
    theta = torch.acos(torch.sum(x0 * x1) / (torch.norm(x0) * torch.norm(x1)))
    w0 = torch.sin((1.0 - alpha) * theta) / torch.sin(theta)
    w1 = torch.sin(alpha * theta) / torch.sin(theta)
    return w0 * x0 + w1 * x1

我们这次要实现的和原论文不同,原论文的插值只在一行内部,我们希望实现一个二维的插值,也就是在一个图片网格中,从左上角到右下角存在一个渐变效果。为此,我们需要先构建一个二维的图片网格,然后按以下的步骤完成二维插值:

  1. 初始化网格四角的 x T ∼ N ( 0 , 1 ) \mathbf{x}_T\sim\mathcal{N}(0,1) xTN(0,1)
  2. 在网格的最左侧和最右侧两列中进行插值,例如最左侧的一列由左上角与左下角两个样本插值得到、最右侧的一列由右上角与右下角的两个样本插值得到;
  3. 遍历所有行,把每行中间的元素用该行最左侧与最右侧的元素进行插值,完成全部 x T \mathbf{x}_T xT 的初始化。

具体的直接看代码就好:

def interpolation_grid(
    rows: int,
    cols: int,
    in_channels: int,
    sample_size: int,
):
    images = torch.zeros((rows * cols, in_channels, sample_size, sample_size), dtype=torch.float32)
    images[0, ...] = torch.randn_like(images[0, ...]) # top left
    images[cols - 1, ...] = torch.randn_like(images[0, ...]) # top right
    images[(rows - 1) * cols, ...] = torch.randn_like(images[0, ...]) # bottom left
    images[-1] = torch.randn_like(images[0, ...]) # bottom right
    for row in range(1, rows - 1): # interpolate left most column and right most column
        alpha = row / (rows - 1)
        images[row * cols, ...] = slerp(images[0, ...], images[(rows - 1) * cols, ...], alpha)
        images[(row + 1) * cols - 1, ...] = slerp(images[cols - 1, ...], images[-1, ...], alpha)
    for col in range(1, cols - 1): # interpolate others
        alpha = col / (cols - 1)
        images[col::cols, ...] = slerp(images[0::cols, ...], images[cols - 1::cols, ...], alpha)
    return images

最后把 images 的初始化从 torch.randn 改成调用 interpolation_grid

images = interpolation_grid(rows, cols, in_channels, sample_size).to(unet.device)

看一下结果如何:

语义插值效应

感觉还不错,那么 DDIM 的学习到这里就告一段落了。

总结

感觉 DDIM 还是非常神奇的,通过改变推导方式去除了对马尔可夫假设的依赖,而且最后表达式中几个复杂的项相互都可以消掉,最后得到一个比较优美的结果。而且最重要的是采样速度真的变快了好多,也因此我直接把实验从集群上搬到了我自己的 PC 上,的确很高效。

本文的代码在如下的链接中,后续还会更新更多 diffusion models 相关的文章,欢迎追更:

参考资料:

  1. diffusion model(二):DDIM技术小结 (denoising diffusion implicit model)
  2. 扩散模型(一)| DDPM & DDIM

本文原文以 CC BY-NC-SA 4.0 许可协议发布于 笔记|扩散模型(二):DDIM 理论与实现,转载请注明出处。

  • 22
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值