紧接上回:【动手深度学习-笔记】注意力机制(一)注意力机制框架
注意力评分函数
回顾使用高斯核的Nadaraya-Watson 核回归:
f
(
x
)
=
∑
i
=
1
n
α
(
x
,
x
i
)
y
i
=
∑
i
=
1
n
exp
(
−
1
2
(
x
−
x
i
)
2
)
∑
j
=
1
n
exp
(
−
1
2
(
x
−
x
j
)
2
)
y
i
=
∑
i
=
1
n
s
o
f
t
m
a
x
(
−
1
2
(
x
−
x
i
)
2
)
y
i
.
(1)
\begin{split}\begin{aligned} f(x) &=\sum_{i=1}^n \alpha(x, x_i) y_i\\ &= \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}(x - x_i)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}(x - x_j)^2\right)} y_i \\&= \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}(x - x_i)^2\right) y_i. \end{aligned}\end{split}\tag{1}
f(x)=i=1∑nα(x,xi)yi=i=1∑n∑j=1nexp(−21(x−xj)2)exp(−21(x−xi)2)yi=i=1∑nsoftmax(−21(x−xi)2)yi.(1)
我们将高斯核指数部分
−
1
2
(
x
−
x
i
)
2
-\frac{1}{2}(x - x_i)^2
−21(x−xi)2视为注意力评分函数(attention scoring function), 简称评分函数(scoring function), 然后把这个函数的输出结果输入到softmax函数中进行运算。 通过上述步骤,我们将得到与键对应的值的概率分布(即注意力权重)。 最后,注意力汇聚的输出就是基于这些注意力权重的值的加权和。
至此,我们可以根据上述的步骤,将注意力机制框架进一步细化描述:
用严格的数学语言描述,对于一个查询
q
∈
R
q
\mathbf{q} \in \mathbb{R}^q
q∈Rq和m个键值对
(
k
1
,
v
1
)
,
…
,
(
k
m
,
v
m
)
,
k
i
∈
R
k
,
v
i
∈
R
v
(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m), \mathbf{k}_i \in \mathbb{R}^k,\mathbf{v}_i \in \mathbb{R}^v
(k1,v1),…,(km,vm),ki∈Rk,vi∈Rv。注意力汇聚函数
f
f
f可表示为:
f
(
q
,
(
k
1
,
v
1
)
,
…
,
(
k
m
,
v
m
)
)
=
∑
i
=
1
m
α
(
q
,
k
i
)
v
i
∈
R
v
,
f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v,
f(q,(k1,v1),…,(km,vm))=i=1∑mα(q,ki)vi∈Rv,
其中
α
(
q
,
k
i
)
\alpha(\mathbf{q}, \mathbf{k}_i)
α(q,ki)为注意力权重, 是将
q
,
k
i
\mathbf{q}, \mathbf{k}_i
q,ki通过注意力评分函数
a
a
a得到一个相似性度量(标量),再通过softmax运算得到的概率分布:
α
(
q
,
k
i
)
=
s
o
f
t
m
a
x
(
a
(
q
,
k
i
)
)
=
exp
(
a
(
q
,
k
i
)
)
∑
j
=
1
m
exp
(
a
(
q
,
k
j
)
)
∈
R
.
\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}.
α(q,ki)=softmax(a(q,ki))=∑j=1mexp(a(q,kj))exp(a(q,ki))∈R.
选择不同的注意力评分函数 a a a会导致不同的注意力汇聚操作。 流行的注意力评分函数有加性注意力(additive attention)评分和缩放点积注意力(scaled dot-product attention)评分
加性注意力
当查询
q
\mathbf{q}
q和键
k
\mathbf{k}
k的长度不同时,可以使用加性注意力评分函数:
a
(
q
,
k
)
=
v
⊤
tanh
(
W
q
q
+
W
k
k
)
∈
R
(2)
a(\mathbf q, \mathbf k) = \mathbf v^\top \text{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R}\tag{2}
a(q,k)=v⊤tanh(Wqq+Wkk)∈R(2)
其中输入为查询
q
∈
R
q
\mathbf{q} \in \mathbb{R}^q
q∈Rq和键
k
∈
R
k
\mathbf{k} \in \mathbb{R}^k
k∈Rk;
分别和两个权重矩阵
W
q
∈
R
h
×
q
,
W
k
∈
R
h
×
k
\mathbf W_q\in\mathbb R^{h\times q},\mathbf W_k\in\mathbb R^{h\times k}
Wq∈Rh×q,Wk∈Rh×k相乘并相加,得到长为
h
h
h的列向量;
使用
tanh
\tanh
tanh作为激活函数,最后和值向量
v
∈
R
h
\mathbf v\in\mathbb R^{h}
v∈Rh的转置相乘,得到一个标量值。
相当于将查询和键连结起来后输入到一个单隐藏层感知机(MLP)中,其隐藏单元数 h h h是一个超参数
缩放点积注意力
查询和键具有相同的长度的情况下,我们可以使用缩放点积注意力评分来提高计算效率:
a
(
q
,
k
)
=
q
⊤
k
/
d
.
a(\mathbf q, \mathbf k) = \mathbf{q}^\top \mathbf{k} /\sqrt{d}.
a(q,k)=q⊤k/d.
实践中往往是以小批次进行计算,假设查询和键的长度
d
d
d,值的长度为
v
v
v,一个批次大小为
n
n
n,键的数量为
m
m
m则有:
s
o
f
t
m
a
x
(
Q
K
⊤
d
)
V
∈
R
n
×
v
.
\mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}.
softmax(dQK⊤)V∈Rn×v.
其中
Q
∈
R
n
×
d
,
K
∈
R
m
×
d
,
V
∈
R
m
×
v
\mathbf Q\in\mathbb R^{n\times d},\mathbf K\in\mathbb R^{m\times d},\mathbf V\in\mathbb R^{m\times v}
Q∈Rn×d,K∈Rm×d,V∈Rm×v