详解加性注意力和缩放点积注意力
注意力机制
注意力机制的本质是加权求和,评分函数a
计算查询与键之间的注意力分数,经过softmax
计算得到注意力权重,权重与值进行加权求和得到注意力输出。
用数学语言描述,假设有一个查询 𝐪∈ℝ𝑞和 𝑚个“键-值”对 (𝐤1,𝐯1),…,(𝐤𝑚,𝐯𝑚), 其中𝐤𝑖∈ℝ𝑘,𝐯𝑖∈ℝ𝑣。 注意力汇聚函数𝑓就被表示成值的加权和:
f
(
q
,
(
k
1
,
v
1
)
,
…
,
(
k
m
,
v
m
)
)
=
∑
i
=
1
m
α
(
q
,
k
i
)
v
i
∈
R
v
f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v
f(q,(k1,v1),…,(km,vm))=i=1∑mα(q,ki)vi∈Rv
其中查询𝐪和键𝐤𝑖的注意力权重(标量)是通过注意力评分函数𝑎
将两个向量映射成标量, 再经过softmax
运算得到的:
α
(
q
,
k
i
)
=
s
o
f
t
m
a
x
(
a
(
q
,
k
i
)
)
=
exp
(
a
(
q
,
k
i
)
)
∑
j
=
1
m
exp
(
a
(
q
,
k
j
)
)
∈
R
\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}
α(q,ki)=softmax(a(q,ki))=∑j=1mexp(a(q,kj))exp(a(q,ki))∈R
选择不同的注意力评分函数𝑎
会导致不同的注意力汇聚操作,主流的评分函数有:加性和缩放点积两种,下面分别介绍。
加性注意力
给定查询𝐪∈ℝ𝑞和 键𝐤∈ℝ𝑘, 加性注意力(additive attention)的评分函数为:
a
(
q
,
k
)
=
w
v
⊤
tanh
(
W
q
q
+
W
k
k
)
∈
R
a(\mathbf q, \mathbf k) = \mathbf w_v^\top \text{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R}
a(q,k)=wv⊤tanh(Wqq+Wkk)∈R
其中可学习的参数是𝐖𝑞∈ℝℎ×𝑞、 𝐖𝑘∈ℝℎ×𝑘和 𝐰𝑣∈ℝℎ。 加性注意力评分函数可以看作,将查询和键连结起来后输入到一个多层感知机(MLP)中, 感知机包含一个隐藏层,其隐藏单元数是一个超参数ℎ。 通过使用tanh作为激活函数,并且禁用偏置项,如下图所示:
缩放点积注意力
缩放点积注意力(scaled dot-product attention)评分函数为:
a
(
q
,
k
)
=
q
⊤
k
/
d
a(\mathbf q, \mathbf k) = \mathbf{q}^\top \mathbf{k} /\sqrt{d}
a(q,k)=q⊤k/d
其中q,k均为d维向量。在实践中,我们通常从小批量的角度来考虑提高效率, 例如基于𝑛个查询和𝑚个键-值对计算注意力, 其中查询和键的长度为𝑑,值的长度为𝑣。 查询𝐐∈ℝ𝑛×𝑑、 键𝐊∈ℝ𝑚×𝑑和 值𝐕∈ℝ𝑚×𝑣的缩放点积注意力是:
s
o
f
t
m
a
x
(
Q
K
⊤
d
)
V
∈
R
n
×
v
\mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}
softmax(dQK⊤)V∈Rn×v
假设查询和键的所有元素都是独立的随机变量, 并且都满足零均值和单位方差, 那么两个向量的点积的均值为0,方差为d。 为确保无论向量长度如何, 点积的方差在不考虑向量长度的情况下仍然是1, 将点积除以𝑑^0.5。
总结
- 加性注意力和缩放点积注意力计算复杂度接近,但矩阵乘法有非常成熟的加速实现,所以缩放点积注意力的计算效率更高。
- 在d(注意力矩阵的维度)较小时,加性和缩放点积注意力效果接近,但随着d的增大,加性注意力开始显著超越缩放点积。原因是极大的点积值将整个 softmax 推向梯度平缓区,使得收敛困难,所以缩放点积注意力需要除以d^0.5。