扩散模型和Transformer梦幻联动!替换U-Net,一举拿下新SOTA!

点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达

点击进入—>扩散模型微信技术交流群

转载自:量子位

“U-Net已死,Transformer成为扩散模型新SOTA了!”

就在ChatGPT占尽AI圈风头时,纽约大学谢赛宁的图像生成模型新论文横空出世,收获一众同行惊讶的声音。

6db0dd67c805f0bca7eec0a90c957c12.png
MILA在读ML博士生Ethan Caballero

论文创意性地将Transformer与扩散模型融合,在计算效率和生成效果上均超越了基于U-Net的经典模型ADM和LDM,打破了U-Net统治扩散模型的“普遍认知”。

f0c0de1138f54e7555fa2801d21c9fc7.png

网友给这对新组合命名也是脑洞大开:

All we need is U-Transformer

希望他们没有错过Transffusion这个名字。

b2eb88dbd3eda292fb1d95f4412e6971.png

要知道,这几年虽然Transformer占尽风头,但U-Net在扩散模型领域仍然一枝独秀——

无论是“前任王者”DALL·E2还是“新晋生成AI”Stable Diffusion,都没有使用Transformer作为图像生成架构。

b53820c96dca0a7f255bf04152c38cac.png
英伟达AI科学家Jim Fan

如今新研究表明,U-Net并非不可用Transformer替代。

“U-Net并非不可替代”

论文提出的新架构名叫Diffusion Transformers(DiTs)。

架构保留了很多ViT的特性,其中整体架构如图左(包含多个DiT模块),具体的DiT模块组成如图右:

656cd4140b451e2f2ada1ac6e7de1a81.png

更右边的两个灰色框的模块,则是DiT架构的“变体”。主要是探讨在条件输入下,不同的架构是否能对信息进行更好的处理,包括交叉注意力等。

最终结果表明,还是层归一化(Layer Normalization)更好用,这里最终选用了Adaptive Layer Normalization(自适应层归一化)的方法。

对于这篇论文研究的目的,作者表示希望探讨扩散模型中不同架构选择的重要性,以及也是给将来生成模型的评估做一个评判标准。

先说结果——作者认为,U-Net的归纳偏置(inductive bias),对于扩散模型性能提升不是必须的。

与之相反,他们能“轻松地”(readily)被Transformer的标准架构取代。

311a6c45114fa7a7dd64b8dea9e13017.png

有网友发现,DALL·E和DALL·E2似乎都有用到Transformer。

这篇论文和它们的差异究竟在哪里?

事实上,DALL·E虽然是Transformer,但并非扩散模型,本质是基于VQVAE架构实现的;

3fbc868d14708ac0dcc279726edb487c.png

至于DALL·E2和Stable Diffusion,虽然都分别将Transformer用在了CLIP和文本编码器上,但关键的图像生成用的还是U-Net。

64c0d961b3d8b668db087c33d2b564e3.png
经典U-Net架构

不过,DiT还不是一个文本生成图像模型——目前只能基于训练标签生成对应的新图像。

虽然生成的图片还带着股“ImageNet风”,不过英伟达AI科学家Jim Fan认为,将它改造成想要的风格和加上文本生成功能,都不是难点。

如果将标签输入调整成其他向量、乃至于文本嵌入,就能很快地将DiT改造成一个文生图模型:

Stable-DiT马上就要来了!

a109b6dc8ddf9143177eca03b593c07f.png

所以DiTs在生成效果和运算速率上,相比其他图像生成模型究竟如何?

在ImageNet基准上取得SOTA

为了验证DiTs的最终效果,研究者将DiTs沿“模型大小”和“输入标记数量”两个轴进行了缩放。

具体来说,他们尝试了四种不同模型深度和宽度的配置:DiT-S、DiT-B、DiT-L和DiT-XL,在此基础上又分别训练了3个潜块大小为8、4和2的模型,总共是12个模型。

2a6ed8bb42ce139a31305eb7d0460840.png

从FID测量结果可以看出,就像其他领域一样,增加模型大小和减少输入标记数量可以大大提高DiT的性能。

FID是计算真实图像和生成图像的特征向量之间距离的一种度量,越小越好。

换句话说,较大的DiTs模型相对于较小的模型是计算效率高的,而且较大的模型比较小的模型需要更少的训练计算来达到给定的FID。

其中,Gflop最高的模型是DiT-XL/2,它使用最大的XL配置,patch大小为2,当训练时间足够长时,DiT-XL/2就是里面的最佳模型。

c9acc2dbc389f294d4d03f5ab3701797.png

于是在接下来,研究人员就专注于DiT-XL/2,他们在ImageNet上训练了两个版本的DiT-XL/2,分辨率分别为256x256和512x512,步骤分别为7M和3M。

当使用无分类器指导时,DiT-XL/2比之前的扩散模型数据都要更好,取得SOTA效果:

在256x256分辨率下,DiT-XL/2将之前由LDM实现的最佳FID-50K从3.60降至了2.27。

并且与基线相比,DiTs模型本身的计算效率也很高:

DiT-XL/2的计算效率为119 Gflops,相比而言LDM-4是103 Gflops,ADM-U则是742 Gflops。

fa20265d98a69e9cbe3ec1050f71a092.png

同样,在512x512分辨率下,DiT-XL/2也将ADM-U之前获得的最佳FID 3.85降至了3.04。

不过此时ADM-U的计算效率是2813 Gflops,而XL/2只有525 Gflops。

64792d9f463f4c31219dd1f9fc6ec461.png

研究作者

本篇论文作者为UC伯克利的William Peebles和纽约大学的谢赛宁。

8829250b14e198d21ca3e8b504f94af1.png

Scalable Diffusion Models with Transformers
论文地址:

https://arxiv.org/abs/2212.09748

代码:https://github.com/facebookresearch/DiT

William Peebles,目前是UC伯克利的四年级博士生,本科毕业于麻省理工学院。研究方向是深度学习和人工智能,重点是深度生成模型。

1272479fee2fb5f538b83c86738bcf1f.png

之前曾在Meta、Adobe、英伟达实习过,这篇论文就是在Meta实习期间完成。

谢赛宁,纽约大学计算机科学系助理教授,之前曾是Meta FAIR研究员,本科就读于上海交通大学ACM班,博士毕业于UC圣迭戈分校。

谢赛宁读博士时曾在FAIR实习,期间与何恺明合作完成ResNeXt,是该论文的一作,之前何恺明一作论文MAE他也有参与。

4790ddbd952cfc2cfaded35df82fa0a8.png

当然,对于这次Transformer的表现,也有研究者们表示“U-Net不服”。

例如三星AI Lab科学家Alexia Jolicoeur-Martineau就表示:

U-Net仍然充满生机,我相信只需要经过细小调整,有人能将它做得比Transformer更好。

看来,图像生成领域很快又要掀起新的“较量风暴”了。

参考链接:
[1]https://twitter.com/ethanCaballero/status/1605621603135471616
[2]https://www.wpeebles.com/DiT
[3]https://paperswithcode.com/paper/scalable-diffusion-models-with-transformers#code

 
 

点击进入—>扩散模型微信技术交流群

DiT论文和代码下载

 
 

后台回复:DiT,即可下载上面论文和代码

目标检测和Transformer交流群成立
扫描下方二维码,或者添加微信:CVer222,即可添加CVer小助手微信,便可申请加入CVer-目标检测或者Transformer 微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer等。
一定要备注:研究方向+地点+学校/公司+昵称(如目标检测或者Transformer+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群

▲扫码或加微信号: CVer222,进交流群
CVer学术交流群(知识星球)来了!想要了解最新最快最好的CV/DL/ML论文速递、优质开源项目、学习教程和实战训练等资料,欢迎扫描下方二维码,加入CVer学术交流群,已汇集数千人!

▲扫码进群
▲点击上方卡片,关注CVer公众号
整理不易,请点赞和在看
### Transformer与U-Net在CCA(典型相关分析)方面的差异 #### 差异概述 TransformerU-Net在架构设计上存在显著的不同,这直接影响了它们在典型相关分析(Canonical Correlation Analysis, CCA)中的表现。具体来说: 1. **注意力机制 vs 卷积操作** Transformer的核心在于自注意力机制(Self-Attention Mechanism),它能够捕捉全局依赖关系并动态调整权重[^3]。这种特性使得Transformer更适合处理具有复杂空间关联的数据集,在CCA中可以更有效地提取跨通道的相关性。相比之下,U-Net主要依靠卷积层来捕获局部特征,并通过跳跃连接传递多尺度信息[^1]。虽然这种方式对于医学影像分割等任务非常有效,但在高维数据的空间建模方面可能不如Transformer灵活。 2. **特征表示能力** 在CCA的应用场景下,Transformer倾向于生成更加抽象且语义丰富的特征向量,因为其基于位置编码的位置感知能力强大的序列建模功能有助于揭示隐藏模式[^2]。然而,U-Net由于专注于像素级重建以及边缘细节保留等问题,则通常会产生较为具体的低层次视觉描述子。因此当涉及到高层次概念之间的联系挖掘时(如不同模态间的关系),Transformers可能会占据优势地位. 3. **计算资源消耗** 尽管Transformer提供了卓越的性能潜力,但它也带来了较高的内存占用训练时间成本。这是因为全图范围内计算点对点相互作用所需的二次方阶运算复杂度所致。而对于许多实际应用而言(U-Net擅长领域),实时性轻量化可能是优先考虑因素之一;此时采用优化后的版本或者简化版网络结构会成为更好的选择方案. --- ### 应用场景对比 | 特性/方法 | Transformer | U-Net | |------------------|------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------| | 数据类型 | 更适合处理大规模、高维度、非欧几里得结构化数据 (e.g., 自然语言处理). | 主要应用于二维或三维网格状输入信号,例如MRI扫描图像或其他生物医学成像资料. | | 关注重点 | 跨区域长距离交互及整体布局理解 | 局部纹理精细刻画加上多层次融合 | | 计算开销 | 较大 | 相对较小 | | 实际案例举例 | 文本转图片生成(Time模型实现反事实解释)[^1], 多源遥感数据分析 | 图像语义分隔(WET-Unet改进传统Unets ), 细胞核检测 | 综上所述,如果研究目标侧重于探索复杂的内在逻辑规律并通过数学工具加以验证的话,那么选用transformer将会是一个明智之举; 反之若是追求快速部署易于维护的小型项目则推荐利用unet框架构建解决方案。 --- ### 示例代码展示两者差异 以下是两个简单例子分别展示了如何使用PyTorch库定义基础形式下的transformer encoder block unet basic building blocks: ```python import torch.nn as nn from einops.layers.torch import Rearrange class SimpleTransformerBlock(nn.Module): def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512): super().__init__() self.norm1 = nn.LayerNorm(dim) self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads) def forward(self, x): identity = x x = self.norm1(x) attn_output, _ = self.attn(query=x, key=x, value=x) out = identity + attn_output # Residual connection return out class DoubleConv(nn.Module): """Double Convolution Layer used in UNet""" def __init__(self, in_channels, out_channels): super(DoubleConv, self).__init__() self.double_conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same'), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same'), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)) def forward(self, x): return self.double_conv(x) ``` 上述代码片段清晰表明了二者设计理念的区别——前者强调可变长度序列间的相对重要程度评估后者聚焦固定大小窗口内部邻近单元之间的影响传播路径规划。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值