DiT Scalable Diffusion Models with Transformers

DiT: Scalable Diffusion Models with Transformers

TL; DR:将 Diffusion 中常用的 UNet 模型替换为 Transformer 模型,并探究了为 Transformer 模型注入条件的结构设计和 scaling 的能力。实验显示,随着模型计算复杂度的上升,生成质量也稳步提升,Transformer 在 Diffusion 中的 scaling 能力得到了验证。而 adaLN 相较于 in-context、cross-attention 等结构更适合为 Transformer 模型注入条件。


Diffusion 仅要求其去噪网络是一个输入输出等尺寸的图到图模型,基于 CNN 的 UNet 是一个很自然的选择。但近些年来,凭借着更强的全局理解能力和 scaling 能力,Transformer 在视觉领域也大放异彩。本文中作者使用 ViT 替换掉 UNet,在 Diffusion 中也验证了 Transformer 的 scaling 能力。并试验了 in-context、cross-attention、adaLN 等不同的条件注入方式。

Diffusion Transformer及其条件注入结构设计

patchify

我们知道 Transformer 处理的是一维的序列,而在(图片) Diffusion 中,我们处理的是二维的图片或者特征图,这里需要将特征图展平为一维的序列。为了指示patch 之间的位置关系,DiT 使用了 sin-cos 的位置编码,这个位置编码是不可学习的。

在这里插入图片描述

DiT block design

在 patchify 之后,就可以把各个 patch token 送到 Transformer 中进行处理,此时就已经可以做无条件生成了。如果要做有条件生成,还需要将条件编码并输入到模型中。在常用的 Classfier-free Guidance 中,就是要将条件 embedding 和非条件 embedding 输入到模型中。在 LDM 中,是通过 cross-attention 的方式输入到 UNet 模型中的。那么在 Transformer 结构中,如何把条件注入到模型中呢?

在这里插入图片描述

DiT 探究了四种对 Transformer 的条件注入方法:

  1. In-context conditioning:将两个 embedding 作为两个 special tokens 拼接到图像块 token 后,类似 ViT 中的 cls token,实现起来比较简单,基本没有额外的计算量。
  2. Cross-attention:将两个 embedding 拼接起来,然后在 transformer block 中插入一个 cross attention,将 embedding 作为 cross attention的 K 和 V;这也是,该方法需要引入的额外计算量最大,约增加 15%。
  3. Adaptive layer norm (adaLN):adaLN 在 GAN 这类生成模型中的应用非常防范。将常规的 LN 替换为 adaLN,回归 scale 和 shift 两个参数,这种方式也基本不增加计算量。
  4. adaLN-zero:即采用零初始化,将 adaLN 的线性层参数初始化为零,网络初始化时 transformer block 的残差模块就是一个 identity 函数。除了回归scale 和 shift,还在每个残差模块结束之前回归一个 scale。

实验对比这四种结构的性能,adaLN-zero 最优。

在这里插入图片描述

Transformer Decoder

在隐层进行 Diffusion 去噪之后,需要输出的还是特征图,再送到 vae decoder 中解码出真实图片。但 Transformer 的输出是一维序列,所以这里还需要一个 Transformer Decoder 来将一维序列变换为与输入同尺寸的特征图。这里直接用了一个标准的线性解码器。

Scaling 性能

之所以要用 Transformer 替换掉 UNet,很重要的一个原因就是 Transformer 的 scaling 能力更强,给更多的参数和计算量,就能有更好的性能。

关于 scaling 能力最关键的实验如下图所示。可以看到,生成质量的提升能跟住 log 尺度的计算量增加,scaling 能力相当不错了。

在这里插入图片描述

总结

DiT 的思路和做法都是比较直接的,就是用 Transformer 替换掉 UNet,并探索了其 scaling 能力和条件注入的结构。

在 Diffusion 成功之后,以 DiT 为代表的,有很多用 Transformer 替换 UNet 的工作,但都没有受到很大的关注。究其原因,应该是 UNet 已经做的足够好,计算开销也低,使用 Transformer scaling 上去,虽然生成质量有所提升,但也没那么显著。

然而,在最近 OpenAI 的 sora 炸裂登场之后,借助 Transformer 实现了任意长度任意尺寸的视频生成,最大长度更是达到了惊人的 60 秒,对比一众在四五秒挣扎的现有模型来说,真算得上是降维打击了。我们或许应该好好想想如何更好地利用起 Transformer 的特点了。

### Diffusion Transformer (DiT) 概述 Diffusion Transformer 是一种创新性的扩散模型,融合了去噪扩散概率模型(DDPM)和Transformer架构的特点[^1]。这种组合使得 DiT 不仅继承了传统扩散模型的强大生成能力,还具备了 Transformer 架构处理序列数据的优势。 ### 工作原理 #### 扩散过程 在传统的扩散模型中,图像或其他形式的数据通过一系列逐步增加噪声的过程被破坏,最终变成纯噪声。而在反向过程中,则是从纯噪声逐渐恢复到原始数据的状态。这一正向和反向过程构成了扩散模型的核心机制[^3]。 对于 Diffusion Transformer 来说,在前向传播阶段会按照一定的时间步长 t 向输入加入高斯白噪音;到了逆向重建环节则依赖于训练好的网络预测每一步应该去除多少噪音来逼近原图特征分布。此过程可以表示为: \[ q(\mathbf{x}_t|\mathbf{x}_{t-1})=\mathcal{N}(\sqrt{1-\beta_t}\mathbf{x}_{t-1};\mathbf{0},\beta_t \mathbf{I}) \] 其中 \(q\) 表示真实数据的概率密度函数,\(β_t\) 控制着每次迭代时所引入的新随机成分的比例大小。 #### Transformer 结合 为了更好地捕捉长期依赖关系并提高建模效率,Diffusion Transformer 将经典的自注意力机制融入进来。具体而言,通过对不同时间戳下的状态进行编码解码操作,实现了对整个演变路径的有效学习与模拟。此外,采用 Patchify 技术将图片切分成多个小块作为 token 输入给 transformer 层进一步增强了局部细节的表现力。 ```python class DiTBlock(nn.Module): def __init__(self, dim, num_heads=8, mlp_ratio=4., drop_path_rate=0.): super().__init__() self.norm1 = nn.LayerNorm(dim) self.attn = Attention(dim, num_heads=num_heads) # MLP block follows the attention layer. hidden_features = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=hidden_features, out_features=dim, act_layer=nn.GELU, drop=drop_path_rate) def forward(self, x): shortcut = x x = self.norm1(x) x = self.attn(x) x += shortcut x = x + self.mlp(self.norm2(x)) return x ``` 这段代码展示了如何实现一个基本的 DiT Block,它包含了标准化层、多头注意模块以及一个多层感知机组成的MLP结构用于增强表达能力。 ### 应用场景 由于其独特的设计思路,Diffusion Transformer 可广泛应用于多种领域内的复杂任务之中,比如但不限于自然语言处理中的文本摘要生成、机器翻译;计算机视觉方面的人脸识别、风格迁移等。相较于其他类型的生成对抗网络GANs 或者变分自动编码VAEs ,DiTs 更擅长解决那些需要精确控制输出质量的任务,并且能够提供更加稳定可靠的性能表现[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值