我们知道BEVFormer基于Transformer,所以要想很好的理解BEVFormer,先要理解Transformer,在这篇文章里,我们先讲解一下Transformer的基本概念。
1. 自注意力机制
Transformer中最重要的一个概念之一就是自注意力机制。
1.1. 输入信号定义
Transformer的输入张量是一个张量序列,其中第i个元素为:
x
(
i
)
∈
R
n
\boldsymbol{x}^{(i)} \in R^{n}
x(i)∈Rn,我们为了讨论方便,假设输入序列只有4个张量组成:
x
(
1
)
,
x
(
2
)
,
x
(
3
)
,
x
(
4
)
\boldsymbol{x}^{(1)},\boldsymbol{x}^{(2)}, \boldsymbol{x}^{(3)}, \boldsymbol{x}^{(4)}
x(1),x(2),x(3),x(4)
1.2. 定义权值矩阵
我们定义网络需要学习的权值矩阵:
- W K ∈ R d m × d k W^{K} \in R^{d_{m} \times d_{k} } WK∈Rdm×dk,对应于每个输入张量的键Key;
- W V ∈ R d m × d v W^{V} \in R^{d_{m} \times d_{v} } WV∈Rdm×dv,对应于每个输入张量的值,我们可以理解为一个输入张量由键、值对表示;
- W Q ∈ R d m × d q W^{Q} \in R^{d_{m} \times d_{q} } WQ∈Rdm×dq,对应于每个输入张量要计算它与其他输入张量关联度数值的查询条件(在所有输入张量的键K中查找);
1.3. 定义输入张量的键
k ( 1 ) = ( W K ) T ⋅ x ( 1 ) ∈ R d k k ( 2 ) = ( W K ) T ⋅ x ( 2 ) ∈ R d k k ( 3 ) = ( W K ) T ⋅ x ( 3 ) ∈ R d k k ( 4 ) = ( W K ) T ⋅ x ( 4 ) ∈ R d k \boldsymbol{k}^{(1)} = (W^{K})^{T} \cdot \boldsymbol{x}^{(1)} \quad \in R^{d_{k} } \\ \boldsymbol{k}^{(2)} = (W^{K})^{T} \cdot \boldsymbol{x}^{(2)} \quad \in R^{d_{k} } \\ \boldsymbol{k}^{(3)} = (W^{K})^{T} \cdot \boldsymbol{x}^{(3)} \quad \in R^{d_{k} } \\ \boldsymbol{k}^{(4)} = (W^{K})^{T} \cdot \boldsymbol{x}^{(4)} \quad \in R^{d_{k} } \\ k(1)=(WK)T⋅x(1)∈Rdkk(2)=(WK)T⋅x(2)∈Rdkk(3)=(WK)T⋅x(3)∈Rdkk(4)=(WK)T⋅x(4)∈Rdk
1.4. 定义输入张量的值
v ( 1 ) = ( W Q ) T ⋅ x ( 1 ) ∈ R d v v ( 2 ) = ( W Q ) T ⋅ x ( 2 ) ∈ R d v v ( 3 ) = ( W Q ) T ⋅ x ( 3 ) ∈ R d v v ( 4 ) = ( W Q ) T ⋅ x ( 4 ) ∈ R d v \boldsymbol{v}^{(1)} = (W^{Q}) ^{T} \cdot \boldsymbol{x}^{(1)} \in R^{ d_{v} } \\ \boldsymbol{v}^{(2)} = (W^{Q}) ^{T} \cdot \boldsymbol{x}^{(2)} \in R^{ d_{v} } \\ \boldsymbol{v}^{(3)} = (W^{Q}) ^{T} \cdot \boldsymbol{x}^{(3)} \in R^{ d_{v} } \\ \boldsymbol{v}^{(4)} = (W^{Q}) ^{T} \cdot \boldsymbol{x}^{(4)} \in R^{ d_{v} } \\ v(1)=(WQ)T⋅x(1)∈Rdvv(2)=(WQ)T⋅x(2)∈Rdvv(3)=(WQ)T⋅x(3)∈Rdvv(4)=(WQ)T⋅x(4)∈Rdv
1.5. 定义输入张量的查询条件
q ( 1 ) = ( W Q ) T ⋅ x ( 1 ) ∈ R d k q ( 2 ) = ( W Q ) T ⋅ x ( 2 ) ∈ R d k q ( 3 ) = ( W Q ) T ⋅ x ( 3 ) ∈ R d k q ( 4 ) = ( W Q ) T ⋅ x ( 4 ) ∈ R d k \boldsymbol{q}^{(1)} = (W^{Q}) ^{T} \cdot \boldsymbol{x}^{(1)} \in R^{ d_{k} } \\ \boldsymbol{q}^{(2)} = (W^{Q}) ^{T} \cdot \boldsymbol{x}^{(2)} \in R^{ d_{k} } \\ \boldsymbol{q}^{(3)} = (W^{Q}) ^{T} \cdot \boldsymbol{x}^{(3)} \in R^{ d_{k} } \\ \boldsymbol{q}^{(4)} = (W^{Q}) ^{T} \cdot \boldsymbol{x}^{(4)} \in R^{ d_{k} } \\ q(1)=(WQ)T⋅x(1)∈Rdkq(2)=(WQ)T⋅x(2)∈Rdkq(3)=(WQ)T⋅x(3)∈Rdkq(4)=(WQ)T⋅x(4)∈Rdk
1.6. 自注意力
我们以第一个输入张量为例,它的查询条件为 q ( 1 ) \boldsymbol{q}^{(1)} q(1),它会到所有输入张量的键 q ( i ) \boldsymbol{q}^{(i)} q(i)(包括它自己在内)中进行查询,可以得到一个匹配度的数值,然后利用Softmax函数进行归一化,代表输入张量与第1个输入张量的相关程度,将这个相关程度与各个的值 v ( i ) \boldsymbol{v}^{(i)} v(i)相乘,然后再相加,从而得到第1个输入张量对应的输出值。
- 查询:在所有输入张量的键K中模找,其实就是点积
x ( 1 ) \boldsymbol{x}^{(1)} x(1) | x ( 2 ) \boldsymbol{x}^{(2)} x(2) | x ( 3 ) \boldsymbol{x}^{(3)} x(3) | x ( 4 ) \boldsymbol{x}^{(4)} x(4) | |
---|---|---|---|---|
x ( 1 ) \boldsymbol{x}^{(1)} x(1) | ( q ( 1 ) ) T k ( 1 ) (\boldsymbol{q}^{(1)})^{T} \boldsymbol{k}^{(1)} (q(1))Tk(1) | ( q ( 1 ) ) T k ( 2 ) (\boldsymbol{q}^{(1)})^{T} \boldsymbol{k}^{(2)} (q(1))Tk(2) | ( q ( 1 ) ) T k ( 3 ) (\boldsymbol{q}^{(1)})^{T} \boldsymbol{k}^{(3)} (q(1))Tk(3) | ( q ( 1 ) ) T k ( 4 ) (\boldsymbol{q}^{(1)})^{T} \boldsymbol{k}^{(4)} (q(1))Tk(4) |
- 缩放:除以 d k \sqrt{d_{k} } dk这个标量
x ( 1 ) \boldsymbol{x}^{(1)} x(1) | x ( 2 ) \boldsymbol{x}^{(2)} x(2) | x ( 3 ) \boldsymbol{x}^{(3)} x(3) | x ( 4 ) \boldsymbol{x}^{(4)} x(4) | |
---|---|---|---|---|
x ( 1 ) \boldsymbol{x}^{(1)} x(1) | ( q ( 1 ) ) T k ( 1 ) d k \frac {(\boldsymbol{q}^{(1)})^{T} \boldsymbol{k}^{(1)}} {\sqrt{ d_{k} }} dk(q(1))Tk(1) | ( q ( 1 ) ) T k ( 2 ) d k \frac {(\boldsymbol{q}^{(1)})^{T} \boldsymbol{k}^{(2)}} {\sqrt{ d_{k} }} dk(q(1))Tk(2) | ( q ( 1 ) ) T k ( 3 ) d k \frac {(\boldsymbol{q}^{(1)})^{T} \boldsymbol{k}^{(3)}} {\sqrt{ d_{k} }} dk(q(1))Tk(3) | ( q ( 1 ) ) T k ( 4 ) d k \frac {(\boldsymbol{q}^{(1)})^{T} \boldsymbol{k}^{(4)}} {\sqrt{ d_{k} }} dk(q(1))Tk(4) |
- 归一化:上面计算了第1个输入张量对其他输入张量的查询结果,其为一个标量数值,我们利用softmax函数对其进行归一化,代表第1个输入张量与所有输入张量之间规整化后的关联度系数
s o f t m a x ( ( q ( 1 ) ) T k ( 1 ) d k , ( q ( 1 ) ) T k ( 2 ) d k , ( q ( 1 ) ) T k ( 3 ) d k , ( q ( 1 ) ) T k ( 4 ) d k ) = > c 1 , 1 , c 1 , 2 , c 1 , 3 , c 1 , 4 softmax( \frac {(\boldsymbol{q}^{(1)})^{T} \boldsymbol{k}^{(1)}} {\sqrt{ d_{k} }} , \frac {(\boldsymbol{q}^{(1)})^{T} \boldsymbol{k}^{(2)}} {\sqrt{ d_{k} }}, \frac {(\boldsymbol{q}^{(1)})^{T} \boldsymbol{k}^{(3)}} {\sqrt{ d_{k} }}, \frac {(\boldsymbol{q}^{(1)})^{T} \boldsymbol{k}^{(4)}} {\sqrt{ d_{k} }} ) => c_{1,1}, c_{1,2}, c_{1,3},c_{1,4} softmax(dk(q(1))Tk(1),dk(q(1))Tk(2),dk(q(1))Tk(3),dk(q(1))Tk(4))=>c1,1,c1,2,c1,3,c1,4 - 求输出
z ( 1 ) = ∑ i = 1 4 c 1 , i ⋅ v ( i ) \boldsymbol{z}^{(1)} = \sum_{i=1}^{4} c_{1,i} \cdot \boldsymbol{v}^{(i)} z(1)=i=1∑4c1,i⋅v(i)
按照上面的方法,我们可以求出第2、3、4个输入张量对应的输出值。将其按照如下方式进行堆叠:
[ ( z ( 1 ) ) T ( z ( 2 ) ) T ( z ( 3 ) ) T ( z ( 4 ) ) T ] \begin{bmatrix} (\boldsymbol{z}^{(1)}) ^{T} \\ (\boldsymbol{z}^{(2)}) ^{T} \\ (\boldsymbol{z}^{(3)}) ^{T} \\ (\boldsymbol{z}^{(4)}) ^{T} \end{bmatrix} (z(1))T(z(2))T(z(3))T(z(4))T
假设序列长度为T,张量的维度为 d m d_{m} dm,则其形状为 R T × d m R^{T \times d_{m}} RT×dm。
以上为自注意力的基本原理,但是在实际使用中,为了提高计算效率,我们通常用矩阵运算形式。我们将序列长度为k的输入向量按照如下方式堆叠:
X = [ ( x ( 1 ) ) T ( x ( 2 ) ) T . . . . . . ( x ( i ) ) T . . . . . . ( x ( k ) ) T ] ∈ R T × d m X= \begin{bmatrix} (\boldsymbol{x}^{(1)}) ^{T} \\ (\boldsymbol{x}^{(2)}) ^{T} \\ ...... \\ (\boldsymbol{x}^{(i)}) ^{T} \\ ...... \\ (\boldsymbol{x}^{(k)}) ^{T} \\ \end{bmatrix} \in R^{T \times d_{m}} X= (x(1))T(x(2))T......(x(i))T......(x(k))T ∈RT×dm
键计算
K = X ⋅ W K , ∈ R T × d k K = X \cdot W^{K}, \in R^{T \times d_{k} } K=X⋅WK,∈RT×dk
值计算
V = X ⋅ W V , ∈ R T × d v V = X \cdot W^{V}, \in R^{T \times d_{v} } V=X⋅WV,∈RT×dv
查询计算
Q = X ⋅ W Q , ∈ R T × d k Q = X \cdot W^{Q}, \in R^{T \times d_{k} } Q=X⋅WQ,∈RT×dk
计算输出值为:
Z = s o f t m a x ( Q K T d k ) V , ∈ R T × d v Z = softmax \bigg(\frac{ QK^{T} }{ \sqrt{ d_{k} } } \bigg) V, \in R^{T \times d_{v} } Z=softmax(dkQKT)V,∈RT×dv
在实际应用中,通常是多个自注意力头,例如原始论文中就是8个头,对8个头分别做上述计算,得到输出为:
Z 0 , Z 1 , Z 2 , Z 3 , Z 4 , Z 5 , Z 6 , Z 7 Z_{0}, Z{1}, Z{2}, Z{3}, Z{4}, Z{5}, Z{6}, Z{7} Z0,Z1,Z2,Z3,Z4,Z5,Z6,Z7
将其拼接为 Z ∈ R T × ( h × d v ) Z \in R^{T \times (h\times d_{v})} Z∈RT×(h×dv),我们定义输出权重 W o ∈ R ( h × d v ) × d m W^{o} \in R^{(h \times d_{v}) \times d_{m}} Wo∈R(h×dv)×dm,则:
Y = Z ⋅ W o , ∈ R T × d m Y = Z \cdot W^{o}, \in R^{T \times d_{m}} Y=Z⋅Wo,∈RT×dm