本文探索了一类新的基于 Transformer 的扩散模型 Diffusion Transformers (DiTs)。本文训练 latent diffusion models 时,使用 Transformer 架构替换常用的 UNet 架构,且 Transformer 作用于 latent patches 上。

本文探索了一类新的基于 Transformer 的扩散模型 Diffusion Transformers (DiTs)。本文训练 latent diffusion models 时,使用 Transformer 架构替换常用的 UNet 架构,且 Transformer 作用于 latent patches 上。

作者探索了 DiT 的缩放性,发现具有较高 GFLOPs 的 DiT 模型,通过增加 Transformer 宽度或者深度或者输入 token 数量,始终有更好的 FID 值。最大的 DiT-XL/2 模型在 ImageNet 512×512 和 256×256 的测试中优于所有先前的扩散模型,实现了 2.27 的 FID 值。

做了什么工作

  1. 探索了一类新的基于 Transformer 的 Diffusion Model,称为 Diffusion Transformers (DiTs)。
  2. 研究了 DiT 对于模型复杂度 (GFLOPs) 和样本质量 (FID) 的缩放性。
  3. 证明了通过使用 Latent Diffusion Models (LDMs)[1]框架,Diffusion Model 中的 U-Net 架构可以被 Transformer 替换。

1 DiT:Transformer 构建扩散模型

论文名称:Scalable Diffusion Models with Transformers (ICCV 2023, Oral)

论文地址:https//arxiv.org/pdf/2212.09748.pdf

论文主页:https//www.wpeebles.com/DiT.html

1 DiT 论文解读:

1.1 把 Transformer 引入 Diffusion Models

机器学习正经历着 Transformer 架构带来的复兴:NLP,CV 等许多领域正在被 Transformer 模型覆盖。尽管 Transformer 在 Autoregressive Model 中得到广泛应用[2][3][4][5],但是这种架构在生成式模型中较少采用。比如,作为图像领域生成模型的经典方法,Diffusion Models[6][7]却一直使用基于卷积的 U-Net 架构作为骨干网络。

Diffusion Models 的开创性工作 DDPM [8]首次引入了基于 U-Net 骨干网络的扩散模型。U-Net 继承自 PixelCNN++[9][10],变化很少。与标准 U-Net[11]相比,额外的空间 Self-Attention 块 (Transformer 中必不可少的组件) 以较低分辨率穿插。[12]这个工作探索了 U-Net 的几种架构选择,例如自适应归一化层 (Adaptive Normalization Layer[13]为卷积层注入条件信息和通道计数。然而,DDPM 里面 U-Net 的高级设计在很大程度上都保持不变。

本文的目的是探索 Diffusion Models 架构选择的重要性,并为未来生成式模型的研究提供基线。本文的结论表明 U-Net 架构设计对 Diffusion Models 的性能并不重要,并且它们可以很容易地替换为 Transformers。

本文证明了 Diffusion Models 也可以受益于 Transformer 架构,受益于其训练方案,受益于其可扩展性,受益于其鲁棒性和效率等等。标准化架构还将为跨域研究开辟了新的可能性。

1.2 Diffusion Models 简介

DDPM

高斯扩散模型假设有一个前向的加噪过程 (Forward Noising Process),在这个过程中逐渐将噪声应用于真实数据:

Diffusion Transformers (DiTs)_缩放

这个优化的目标函数比较复杂,最后通过 variational lower bound 方法得到的结论是优化下式 (此处详细推导可以参考开创性工作 DDPM[8]):

Diffusion Transformers (DiTs)_架构设计_02

1.3 DiT 架构介绍

1.3.1 Patchify 过程

Diffusion Transformers (DiTs)_架构设计_03

图1:图片的 Patchify 操作。当 Patch 的大小 p 越小时,token 的数量 T 越大

1.3.2 DiT Block 设计

在 Patchify 之后,输入的 tokens 开始进入一系列 Transformer Block 中。除了噪声图像输入之外,Diffusion Model 有时会处理额外的条件信息,比如噪声时间步长 ttt , 类标签 ccc , 自然语言。

作者探索了4种不同类型的 Transformer Block,以不同的方式处理条件输入。这些设计都对标准 ViT Block 进行了微小的修改,所有 Block 的设计如下图2所示。

Diffusion Transformers (DiTs)_架构设计_04

图2:Diffusion Transformer (DiT) 架构

  • In-Context Conditioning

Diffusion Transformers (DiTs)_架构设计_05

作者将以上几种方法 In-Context Conditioning,Cross-Attention Block,Adaptive Layer Norm (adaLN) Block,adaLN-Zero Block 的做法列入 DiT 的设计空间中。

1.3.3 模型尺寸

Diffusion Transformers (DiTs)_架构设计_06

图3:DiT 模型的详细配置

作者将以上几种配置列入了 DiT 的设计空间中。

1.3.4 Transformer Decoder

在最后一个 DiT Block 之后,需要将 image tokens 的序列解码为输出噪声以及对角的协方差矩阵的预测结果。

Diffusion Transformers (DiTs)_迭代_07

最终,完整 DiT 的设计空间是 Patch Size、DiT Block 的架构和模型大小。

1.4 DiT 训练策略

1.4.1 训练配方

作者在 ImageNet 数据集上训练了 class-conditional latent DiT 模型,标准的实验设置。

Diffusion Transformers (DiTs)_人工智能_08

数据增强技术只使用 horizontal flips。

作者发现 learning rate warmup 和 regularization,对训练 DiT 模型而言不是必须的。

作者使用了 exponential moving average (EMA),参数为 0.9999 。

训练超参数基本都来自 ADM,不调学习率, decay/warm-up schedules, Adam 参数以及 weight decay.

1.4.2 扩散模型配置

Diffusion Transformers (DiTs)_架构设计_09

作者保留了 ADM 中使用的超参数。   

1.5.1 DiT 架构设计

作者首先探索的是不同 Conditioning 策略的对比。对于一个 DiT-XL/2 模型,其计算复杂度分别是:in-context (119.4 Gflops), cross-attention (137.6 Gflops), adaptive layer norm (adaLN, 118.6 Gflops), adaLN-zero (118.6 Gflops)。实验结果如下图4所示。

adaLN-Zero 的 Block 架构设计取得了最低的 FID 结果,同时在计算量上也是最高效的。在 400K 训练迭代中,adaLN-Zero Block 架构得到的 FID 几乎是 In-Context 的一半,表明 Condition 策略会严重影响模型的质量。

初始化同样也重要:adaLN-Zero Block 架构在初始化时相当于恒等映射,其性能也大大优于 adaLN Block 架构。

因此,在后续实验中,DiT 将一直使用 adaLN-Zero Block 架构。

Diffusion Transformers (DiTs)_架构设计_10

图4:不同 Conditioning 策略对比

1.5.2 缩放模型尺寸和 Patch Size

作者训练了12个 DiT 模型 (尺寸为 S, B, L, XL,Patch Size 为 8,4,2)。下图是不同 DiT 模型的尺寸和 FID-50K 性能。如下图5所示是不同大小 DiT 模型的 GFLOPs 以及在 400K 训练迭代中的 FID 值。可以发现,在增加模型大小或者减小 Batch Size 时可以显著改善 DiT 的性能。

Diffusion Transformers (DiTs)_迭代_11

图5:不同尺寸 DiT 模型的 GFLOPs 以及它们在 400K 训练迭代中的 FID

下图6上方是 Patch Size 不变,增加模型规模时 FID 的变化。当模型变深变宽时,FID 会下降。

下方是模型规模不变,减小 Patch Size 时 FID 的变化。当 Patch Size 下降时,FID 出现显著改善。

Diffusion Transformers (DiTs)_缩放_12

图5:缩放 DiT 模型可以改善训练各个阶段的 FID

1.5.3 GFLOPs 对性能很重要

上图5的结果表明,参数量并不能唯一确定 DiT 模型的质量。当 Patch Size 减小时,参数量仅仅是略有下降,只有 GFLOPs 明显增加。这些结果都表明了缩放模型的 GFLOPs 才是性能提升的关键。为了印证这一点,作者在下图6中绘制了不同 GFLOPs 模型在 400K 训练步骤时候的 FID-50K 结果。这些结果表明,当不同 DiT 模型的总 GFLOPs 相似时,它们的 FID 值也相似,比如 DiT-S/2 和 DiT-B/4。

作者还发现 DiT 模型的 GFLOPs 和 FID-50K 之间存在很强的负相关关系。

Diffusion Transformers (DiTs)_迭代_13

图6:GFLOPs 与 FID 密切相关

1.5.4 大模型更加计算高效

Diffusion Transformers (DiTs)_架构设计_14

图7:大模型更加计算高效

1.5.5 缩放结果可视化

Diffusion Transformers (DiTs)_架构设计_15

Diffusion Transformers (DiTs)_缩放_16

图8:缩放对于视觉质量的影响

1.6 DiT 实验结果

作者将 DiT 与最先进的生成模型进行了比较,结果如图9所示。DiT-XL/2 优于所有先前的扩散模型,将 LDM 实现的先前最佳 FID-50K 降低到 2.27。图5右侧显示 DiT-XL/2 (118.6 GFLOPs) 相对于 LDM-4 (103.6 GFLOPs) 等Latent Space U-Net 模型的计算效率很高,并且比 Pixel Space U-Net 模型更高效,例如 ADM (1120 GFLOPs) 或 ADM-U (742 GFLOPs)。

Diffusion Transformers (DiTs)_架构设计_17

图9:ImageNet 256×256 图像生成结果

作者在 ImageNet 上训练了一个新的 DiT-XL/2,这次分辨率是 512×512,3M training iterations,超参数与 256×256 模型相同。这个模型 latent 的维度是 64×64×4,然后 Patch Size 为2,这样 Transformer 模型需要处理的 token 的数量就是 1024。如下图10所示是比较结果。DiT-XL/2 在此分辨率下再次优于所有先前的扩散模型,将 ADM 实现的先前最佳 FID 提高了 3.85 到 3.04。即使 token 的数量增加了,DiT-XL/2 的计算效率依然很高,比如 ADM 使用 1983 GFLOPs,ADM-U 使用 2813 GFLOPs,DiT-XL/2 仅仅使用 524.6 GFLOPs。

Diffusion Transformers (DiTs)_缩放_18

图10:ImageNet 512×512 图像生成结果

缩放模型大小还是采样次数?

Diffusion Model 的一个独特之处是它们可以通过在生成图像时增加采样步骤的数量来在训练期间使用额外的计算。也就是扩散模型的计算量既可以来自模型本身的缩放,也可以来自采样次数的增加。因此,作者在这里研究了通过使用更多的采样计算,较小的 DiT 模型是否可以胜过更大的模型。

作者计算了所有的 12 个 DiT 模型在 400K training iteration 时候的 FID 值,每张图分别使用 [16, 32, 64, 128, 256, 1000] sampling steps。

实验结果如下图11所示,考虑使用 1000 个采样步骤的 DiT-L/2 和使用 128 步的 DiT-XL/2。在这种情况下:

  • DiT-L/2 使用 80.7 TFLOPs 对每张图像进行采样。
  • DiT-XL/2 使用 15.2 TFLOPs 对每张图像进行采样。

但尽管如此,DiT-XL/2 具有更好的 FID-10K 结果。说明增加采样的计算量也无法弥补模型本身计算量的缺失。

Diffusion Transformers (DiTs)_迭代_19

图11:增加采样的计算量也无法弥补模型本身计算量的缺失