Cross-Attention Makes Inference Cumbersome in Text-to-Image Diffusion Models
本研究探讨了文本条件扩散模型中交叉注意在推理过程中的作用。我们发现 交叉注意输出经过几个推理步骤后会收敛到一个固定点 。收敛的时间点自然将整个推理过程分为两个阶段:
- 初始语义规划阶段: 在此阶段,模型依赖于交叉关注规划面向文本的视觉语义,以及随后的
- 保真度改进阶: 在此阶段,模型尝试从先前规划的语义生成图像。
作者发现在保真度改进阶段忽略文本条件不仅降低了计算复杂度,而且保持了模型的性能。这产生了一种简单且无需训练的方法,称为TGATE,用于高效生成,它在交叉注意力输出收敛时缓存它,并在剩余的推理步骤中保持固定。我们对MS-COCO验证集的实证研究证实了其有效性。
1. Introduction
一些研究强调了交叉注意对空间控制的重要性(Prompt-to-Prompt, Atten-and-Excite, Boxdiff),但很少(如果有的话)从去噪过程中的时间角度研究其作用。
在这里,我们提出了一个新的问题:“在文本到图像扩散模型的推理过程中,交叉注意力对每一步都是必要的吗?”
为此,我们研究了在每个推理步骤中交叉关注对生成图像质量的影响。我们的发现突出了两个反直觉的观点:
-
在最初的几个步骤中,交叉注意输出收敛到一个固定点。(收敛时间点将扩散模型去噪过程分为两个阶段:)
- 初始阶段,模型依靠交叉注意规划面向文本的视觉语义, 我们将其表示为 语义规划阶段
- 后续阶段,模型学习从先前的语义规划中生成图像,我们称之为保真度提升阶段。
-
交叉注意在保真度提高阶段是多余的。
- 在语义规划阶段,交叉注意对产生有意义的语义起着至关重要的作用。然而,在后期阶段,交叉注意收敛,对生成过程的影响较小。
- 事实上,在保真度提高阶段绕过交叉注意可以在保持图像生成质量的同时潜在地减少计算成本。
因为交叉注意中的缩放点积是一个二次复杂度的运算。随着现代模型中分辨率和令牌长度的不断增加,交叉注意不可避免地会导致昂贵的计算,成为移动设备等应用程序的重要延迟来源。
这一发现促使我们重新评估交叉注意的作用,并启发我们设计一种简单、有效、无需训练的方法,即暂时控制交叉注意(temporally gating the cross-attention (TGATE)
),以提高效率并保持现成扩散模型的生成质量。
💡 需要注意的是:
- TGATE不会导致性能下降,**因为交叉注意的结果是聚合和冗余的。**事实上,观察到在基线上的初始化距离(FID)略有改善。
- TGATE可以在每张图像上减少65T多次累积操作(Multiple-Accumulate Operations,MACs),并在保真度提高阶段减小0.5B个参数,与基线模型(SDXL)相比,在没有训练成本的情况下,延迟减少了约50%。
2. Temporal Analysis of Cross-Attention
2.1 Cross-Attention.
UNet 中的交叉注意数学定义如下:
C c t = Softmax ( Q z t ⋅ K c d ) ⋅ V c \mathbf{C}_c^t=\text{Softmax}(\frac{Q_z^t\cdot K_c}{\sqrt{d}})\cdot V_c Cct=Softmax(dQzt⋅Kc)⋅Vc
其中, Q z t Q_z^t Qzt 是 z t z_t zt 的投影, K c K_c Kc 和 V c V_c Vc 是文本嵌入 c c c 的投影, d d d 是 K c K_c Kc的特征维度。交叉注意的二次型的计算复杂度,是在处理高分辨率特征时的一个重要瓶颈。
2.2 On the Convergence of Cross-Attention Map
考虑一个问题🤔:考虑到每个时间步骤中噪声输入的变化,交叉注意生成的特征图是表现出时间稳定性还是随时间波动?
作者分析的做法步骤如下:
- 从MS-COCO数据集中随机收集了1000个标题,并使用CFG预训练的SD-2.1模型1生成图像。
- 在推理过程中,我们计算 C t C^t Ct 和 C t + 1 C^{t+1} Ct+1 之间的L2距离,其中 C t C^t Ct 表示时间步长t的交叉注意图。
- 通过平均所有输入标题、条件和深度之间的L2距离来获得两步之间的交叉注意差。
图2显示了跨不同推理步骤的交叉注意差异的变化。一个明显的趋势出现了,表明差异逐渐趋近于零。 收敛性总是在5 ~ 10个推理步骤内出现。因此,交叉注意映射会收敛到一个固定点,不再为图像生成提供动态指导。这一发现从交叉注意的角度支持了CFG的有效性,表明尽管条件和初始噪声不同,无条件批次和有条件批次可以收敛到一个一致的结果。
图中的每个数据点是模型中1000个标题和所有交叉注意图的平均值。阴影区域表示方差,而曲线表示连续步骤之间的差逐渐接近于零。
因此我们发现:这一现象说明了交叉注意在推理过程中的影响是不均匀的,并启发了下一节交叉注意的时间分析。
2.3 The Role of Cross-Attention in Inference
Analytical Tool.
我们通过在特定阶段有效地“去除”交叉关注并观察由此产生的生成质量差异来衡量交叉关注的影响。在实践中,这种移除近似于用空文本的占位符替换原始文本嵌入,即“”。我们将标准去噪轨迹形式化为一个序列:
为了简化起见,省略了时间步骤索引 t t t 和引导比例系数 w w w。 从序列 S \boldsymbol{S} S 中生成的图像表示为 x x x, 我们然后修改标准序列,将条件文本嵌入 x x x 替换为空文本嵌入 ∅ \varnothing ∅,在指定的推理区间内,产生两个新的序列, S m F \boldsymbol{S}^F_m SmF 和 S m L \boldsymbol{S}^L_m SmL ,基于一个标量 m m m
这里,
m
m
m 是作为一个 gate step
来分割这个两阶段。在序列
S
m
F
\boldsymbol{S}^F_m
SmF 中,我们将无文本嵌入
v
a
r
n
o
t
h
i
n
g
varnothing
varnothing 替换原始文本嵌入,从
m
+
1
m + 1
m+1 到第
n
n
n 步。
相反的,在序列 S m L \boldsymbol{S}^L_m SmL 中,我们将使用空文本嵌入 v a r n o t h i n g varnothing varnothing,从 1 1 1 到第 m m m 步,而 m m m 到第 n n n 步 则维持使用原始文本嵌入。
我们将遵循这两个去噪轨迹生成的图像分别表示为 x m F x^F_m xmF 和 x m L x^L_m xmL。为了评估交叉关注的作用,我们对比了 x x x, x m F x^F_m xmF 和 x m L x^L_m xmL 之间的生成质量。
如果 x x x 和 x m F x^F_m xmF 之间的生成质量存在显著差异,则表明该阶段交叉注意的重要性。相反,如果没有实质性的变化,则该阶段交叉注意可能不是必须的。
补充:我们使用SD-2.1作为基本模型,并使用DPM求解器(Lu et al., 2022)进行噪声调度。所有实验的推理步长都设为25。文字提示“宇航员在太空骑马的高质量照片。用于可视化。
Results and Discussion.
图3(a)中给出了预测噪声均值的轨迹,经验表明,经过25个推理步骤后,去噪过程收敛。因此,在这个区间内分析交叉注意的影响就足够了。如图3(b)所示,gate step
m
m
m 设置为10,这将产生三个轨迹:
S
,
S
m
F
,
S
m
L
\boldsymbol{S}, \boldsymbol{S_m^F}, \boldsymbol{S_m^L}
S,SmF,SmL
结果表明10步后的交叉注意不影响最终结果。 但是如果在最初的步骤中忽略交叉注意会导致显著的差异。如图3©所示,这种消除导致MS-COCO验证集中的生成质量(FID)显著下降,甚至比不使用CFG生成图像的弱基线更差。
个人感觉 S m L \boldsymbol{S_m^L} SmL 效果不好也正常, 因为 S m L \boldsymbol{S_m^L} SmL 也太极端了,一开始全部置为空。
然后作者对 gate step
m
m
m 分别取
{
3
,
5
,
10
}
\{3,5,10\}
{3,5,10} 做了进一步的实验,当
m
m
m 大于 5 时,忽略交叉注意的模型可以获得更好的 FID。
为了进一步证明我们的发现的普遍性,我们在各种条件下进行了实验,包括一系列的总推断数、噪声调度器和基本模型。如表2、3和4所示,我们报告了 S , S m F , S m L \boldsymbol{S}, \boldsymbol{S_m^F}, \boldsymbol{S_m^L} S,SmF,SmL在MS-COCO验证集上的FID。实验结果一致表明, S m F \boldsymbol{S_m^F} SmF 的FIDs略优于基线S,并且远远优于 S m L \boldsymbol{S_m^L} SmL。这些研究强调了这些发现具有广泛适用性的潜力。
3. Temporally Gating Cross-Attention(TGATE)
前文可以发现最后的推理步骤中交叉注意计算是多余的。 所以考虑在不重新训练模型的情况下删除/替换交叉注意是很重要的。受DeepCache的启发,作者提出了一种有效且无需训练的方法,称为TGATE。该方法缓存语义规划阶段的注意力结果,并在整个保真度改进阶段重用它们。
Caching Cross-Attention Maps.
gate step
为
m
m
m, 对于第
m
m
m 步,对于第
i
i
i 个交叉注意模块,可以通过基于CFG的推理得到
C
m
,
i
c
C_{m,i}^c
Cm,ic 和
C
m
,
i
∅
C_{m,i}^{\varnothing}
Cm,i∅ 两个交叉注意映射。我们计算这两个映射的平均值作为锚点,并将其存储在先进先出特征缓存
F
F
F 中。遍历所有交叉注意块后,
F
F
F 可以写成:
F = { 1 2 ( C ∅ m , i + C c m , i ) ∣ i ∈ [ 1 , l ] } , \mathbf{F}=\{\frac12(\mathbf{C}_\varnothing^{m,i}+\mathbf{C}_c^{m,i})|i\in[1,l]\}, F={21(C∅m,i+Ccm,i)∣i∈[1,l]},
其中, l l l 表示交叉注意模块的总数,在 SD-2.1中, l = 16 l = 16 l=16.
Re-using Cached Cross-Attention Maps.
在保真度改进阶段的每一步中,当向前传递过程中遇到交叉注意操作时,将其从计算图中省略。相反,缓存的 F . p o p ( 0 ) \boldsymbol{F}.pop(0) F.pop(0) 被输入到后续的计算中。这种方法不会在每个时间步产生相同的预测,因为 UNet 中有残差连接允许模型跳过交叉注意。
4. Experiments
四个基准模型
- Stable Diffusion-1.5 (SD-1.5)
- SD-2.1
- SDXL
- PixArt-Alpha(这个是基于 Transformer)
加速的基线模型
- Latent Consistency Model
- Adaptive Guidance
- DeepCache
- Multiple Noise Schedules
为了进行令人信服的实证研究,我们将我们的方法与几种加速基线方法进行了比较:潜在一致性模型(Luo等人,2023)、自适应制导(Castillo等人,2023)、DeepCache (Ma等人,2023)和多个噪声调度程序(Karras等人,2022;Lu et al., 2022;Song等人,2020)。注意,我们的方法与现有的加速去噪推理方法是正交的;因此,我们的方法可以简单集成,进一步加快这一过程。
实验结果