点击下方卡片,关注“CVer”公众号
AI/CV重磅干货,第一时间送达
添加微信:CVer5555,小助手会拉你进群!
扫描下方二维码,加入CVer学术星球!可以获得最新顶会/顶刊上的论文idea和CV从入门到精通资料,及最前沿应用!发论文/搞科研/涨薪,强烈推荐!
论文标题:Cross-Attention Makes Inference Cumbersome in Text-to-Image Diffusion Models
开源代码:https://github.com/HaozheLiu-ST/T-GATE
单位:阿卜杜拉国王科技大学(KAUST),新加坡国立大学(NUS),IDSIA
论文探讨了文本条件扩散模型的交叉注意力(Cross-Attention)模块在推理过程中的作用。研究发现在扩散模型的推断去噪步骤中,交叉注意力模块的输出会收敛到一个固定点。因此,整个推理过程可以分为两个阶段:1. 初始的语义规划阶段,在此阶段模型根据文本信息来设计视觉语义信息;2. 随后的图像保真度提升阶段,模型会根据第一阶段设计的语义信息生成高质量的图像。令人惊讶的是,在推理的第二阶段中,忽略给定的文本条件不仅可以降低模型计算复杂度,也会稍微降低FID分数。根据这种现象,研究团队提出了一种简单且无需训练的高效生成推理方法TGATE,该方法在交叉注意力模块收敛后对其输出进行缓存,并在剩余推理步骤中替换交叉注意力模块。
1. 动机
扩散模型已被广泛用于图像生成。它们使用交叉注意力(Cross-Attention)将不同模态数据进行对齐,例如用于条件生成任务的文本信息。一些研究强调了交叉注意力对空间控制的重要性,但很少有研究从时间角度研究其在去噪过程中的作用。另外,交叉注意力模块中的缩放点积是一种具有$O(n^2)$复杂度的运算。随着现代模型中图像分辨率和token长度的不断增加,交叉注意力模块会带来极其昂贵的计算成本,也会成为移动设备等终端中延迟的重要来源。这促使研究团队重新评估交叉注意力模块的作用。这篇文章探讨了一个新问题:“在文本到图像扩散模型的推理过程中,交叉注意力对每一步都至关重要吗?”。对于这个问题的研究总结如下:
交叉注意力模块在推理过程中会提前收敛。收敛的时间点(time step)将扩散模型的去噪过程分为两个阶段:i)初始阶段,在此阶段,模型依赖于交叉注意力模块来设计面向文本的视觉语义;本文将其表示为语义规划阶段,以及ii)后续阶段,模型学习从先前的语义规划生成图像,即图像保真度提升阶段。
在图像保真度提升阶段,交叉注意力模块是多余的。在语义规划阶段,交叉注意力在创建有意义的语义方面发挥着重要作用。然而,在第二阶段,交叉注意力模块已经完成收敛,它的输出对生成过程的影响很小。事实上,在图像保真度提升阶段绕过交叉注意力模块不仅可以降低计算成本,也可以保持较高的图像生成质量。
2. 贡献
该研究团队设计一种简单、有效且无需训练的方法,即时间门控交叉注意力(TGATE),以提高模型推断效率并保持现有扩散模型的生成质量。该方法在交叉注意力模块收敛后对其输出进行缓存,并在剩余推理步骤中替换交叉注意力模块。主要贡献如下:
TGATE通过在交叉注意力模块收敛后缓存和复用交叉注意力的结果来提高效率,消除冗余的交叉注意力。生成一张图像可以减少65T MACs,并减少推断模型0.5B的参数量,与baseline模型SD-XL相比,图像生成时间减少了约50%。
TGATE不会导致性能下降,由于交叉注意力模块是收敛的和冗余的。在一些情况下生成图像的FID指标比baseline模型甚至略有降低。
TGATE支持扩散模型中,基于CNN的U-Net模型,Transformer模型,以及Consistency Model。也可以结合其他优化算法例如DeepCache,实现更快的模型推断。
3. 实验评估
该论文在主流的预训练扩散模型,包括SD-1.5, SD-2.1, SD-XL和PixArt-Alpha上,使用Multiple-Accumulate Operations (MACs), 参数量(paramters),推断时间(latency)和 MS-COCO-10k数据集上的zero-shot FID四个指标,验证所提出的TGATE方法。实验结果如下
本文也在现有扩散模型加速方法DeepCache和Latent Consistency Model进行验证。
4. 可视化结果
何恺明在MIT授课的课件PPT下载
在CVer公众号后台回复:何恺明,即可下载本课程的所有566页课件PPT!赶紧学起来!
CVPR 2024 论文和代码下载
在CVer公众号后台回复:CVPR2024,即可下载CVPR 2024论文和代码开源的论文合集
Mamba和扩散模型交流群成立
扫描下方二维码,或者添加微信:CVer5555,即可添加CVer小助手微信,便可申请加入CVer-Mamba和扩散模型微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer、NeRF、3DGS、Mamba等。
一定要备注:研究方向+地点+学校/公司+昵称(如Mamba或者扩散模型+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群
▲扫码或加微信号: CVer5555,进交流群
CVer计算机视觉(知识星球)来了!想要了解最新最快最好的CV/DL/AI论文速递、优质实战项目、AI行业前沿、从入门到精通学习教程等资料,欢迎扫描下方二维码,加入CVer计算机视觉(知识星球),已汇集近万人!
▲扫码加入星球学习
▲点击上方卡片,关注CVer公众号
整理不易,请点赞和在看