A Concise Illustrated Guide to Attention Mechanism Optimizations
1. Multi-Head Attention (MHA)
Diagram:
Input --> [Attention Head 1]
--> [Attention Head 2]
--> [Attention Head 3]
--> ...
--> [Attention Head N]
--> [Concatenate] --> Output
Formulas:
\text{Output} = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_N)
\text{head}_i = \text{Attention}(Q_i, K_i, V_i)
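A minimal PyTorch sketch of the formulas above, assuming single projection matrices that are split across heads; the names (Wq, Wk, Wv, Wo) and toy sizes are illustrative, not a reference implementation.

```python
import torch
import torch.nn.functional as F

def multi_head_attention(x, Wq, Wk, Wv, Wo, n_heads):
    """Every head gets its own slice of the Q/K/V projections."""
    B, T, d_model = x.shape
    d_head = d_model // n_heads
    # Project once, then split the channel dimension into n_heads heads.
    q = (x @ Wq).reshape(B, T, n_heads, d_head).transpose(1, 2)  # (B, H, T, d_head)
    k = (x @ Wk).reshape(B, T, n_heads, d_head).transpose(1, 2)
    v = (x @ Wv).reshape(B, T, n_heads, d_head).transpose(1, 2)
    # Scaled dot-product attention per head: softmax(Q K^T / sqrt(d_head)) V
    scores = q @ k.transpose(-2, -1) / d_head ** 0.5             # (B, H, T, T)
    heads = F.softmax(scores, dim=-1) @ v                        # (B, H, T, d_head)
    # Concat(head_1, ..., head_N), then the output projection.
    return heads.transpose(1, 2).reshape(B, T, d_model) @ Wo

B, T, d_model, n_heads = 2, 8, 64, 4  # hypothetical toy sizes
x = torch.randn(B, T, d_model)
Wq, Wk, Wv, Wo = (torch.randn(d_model, d_model) for _ in range(4))
print(multi_head_attention(x, Wq, Wk, Wv, Wo, n_heads).shape)   # torch.Size([2, 8, 64])
```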
2. Multi-Query Attention (MQA)
Diagram:
Input --> [Shared Keys & Values]
--> [Attention Head 1]
--> [Attention Head 2]
--> [Attention Head 3]
--> ...
--> [Concatenate] --> Output
Formulas:
\text{Output} = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_N)
\text{head}_i = \text{Attention}(Q_i, K_{\text{shared}}, V_{\text{shared}})
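A sketch of MQA under the same assumptions as the MHA example: K and V are projected to a single head and broadcast across all query heads, which is what makes the shared KV cache so small.

```python
import torch
import torch.nn.functional as F

def multi_query_attention(x, Wq, Wk, Wv, n_heads):
    """N query heads, but a single K/V head shared by all of them."""
    B, T, d_model = x.shape
    d_head = d_model // n_heads
    q = (x @ Wq).reshape(B, T, n_heads, d_head).transpose(1, 2)  # (B, H, T, d_head)
    # K and V are projected once to one head; broadcasting shares them across heads.
    k = (x @ Wk).reshape(B, T, 1, d_head).transpose(1, 2)        # (B, 1, T, d_head)
    v = (x @ Wv).reshape(B, T, 1, d_head).transpose(1, 2)
    scores = q @ k.transpose(-2, -1) / d_head ** 0.5             # broadcasts to (B, H, T, T)
    heads = F.softmax(scores, dim=-1) @ v
    return heads.transpose(1, 2).reshape(B, T, d_model)

B, T, d_model, n_heads = 2, 8, 64, 4
x = torch.randn(B, T, d_model)
Wq = torch.randn(d_model, d_model)
Wk = torch.randn(d_model, d_model // n_heads)  # single-head K projection
Wv = torch.randn(d_model, d_model // n_heads)  # single-head V projection
print(multi_query_attention(x, Wq, Wk, Wv, n_heads).shape)       # torch.Size([2, 8, 64])
```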
3. Grouped-Query Attention (GQA)
Diagram:
Input --> [Attention Group 1]
--> [Attention Group 2]
--> ...
--> [Concatenate] --> Output
Formulas:
\text{Output} = \text{Concat}(\text{group}_1, \text{group}_2, \ldots, \text{group}_M)
\text{group}_j = \text{Attention}(Q_{\text{group}_j}, K_{\text{group}_j}, V_{\text{group}_j})
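A sketch of GQA, again with illustrative names and sizes: K and V get one head per group, and each group's K/V head is duplicated across the query heads that belong to it. With n_groups = n_heads this reduces to MHA; with n_groups = 1 it reduces to MQA.

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(x, Wq, Wk, Wv, n_heads, n_groups):
    """Query heads are split into n_groups; each group shares one K/V head."""
    B, T, d_model = x.shape
    d_head = d_model // n_heads
    q = (x @ Wq).reshape(B, T, n_heads, d_head).transpose(1, 2)   # (B, H, T, d_head)
    k = (x @ Wk).reshape(B, T, n_groups, d_head).transpose(1, 2)  # (B, G, T, d_head)
    v = (x @ Wv).reshape(B, T, n_groups, d_head).transpose(1, 2)
    # Duplicate each group's K/V for the query heads belonging to that group.
    k = k.repeat_interleave(n_heads // n_groups, dim=1)           # (B, H, T, d_head)
    v = v.repeat_interleave(n_heads // n_groups, dim=1)
    scores = q @ k.transpose(-2, -1) / d_head ** 0.5
    heads = F.softmax(scores, dim=-1) @ v
    return heads.transpose(1, 2).reshape(B, T, d_model)

B, T, d_model, n_heads, n_groups = 2, 8, 64, 4, 2
x = torch.randn(B, T, d_model)
Wq = torch.randn(d_model, d_model)
Wk = torch.randn(d_model, n_groups * (d_model // n_heads))  # one K head per group
Wv = torch.randn(d_model, n_groups * (d_model // n_heads))  # one V head per group
print(grouped_query_attention(x, Wq, Wk, Wv, n_heads, n_groups).shape)
```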
4. Multi-Head Latent Attention (MLA)
Diagram:
Input --> [Compressed Keys & Values]
--> [Attention Head 1]
--> [Attention Head 2]
--> [Attention Head 3]
--> ...
--> [Concatenate] --> Output
Formulas:
\text{Output} = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_N)
\text{head}_i = \text{Attention}(Q_i, K_{\text{compressed}}, V_{\text{compressed}})
Low-rank joint key-value compression:
K_{\text{compressed}} = U_K \cdot S_K \cdot V_K^T
V_{\text{compressed}} = U_V \cdot S_V \cdot V_V^T
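The formulas above write the compression as an SVD-style factorization. The sketch below uses the equivalent down-projection/up-projection view of low-rank compression: the input is jointly compressed once into a small latent c_kv (the part that would be cached at inference), and K and V are reconstructed from it. All names here (W_down, W_up_k, W_up_v, d_latent) are illustrative assumptions, not the original notation.

```python
import torch
import torch.nn.functional as F

def multi_head_latent_attention(x, Wq, W_down, W_up_k, W_up_v, n_heads):
    """K and V are rebuilt from one shared low-rank latent; only the small
    latent c_kv needs to be cached at inference time."""
    B, T, d_model = x.shape
    d_head = d_model // n_heads
    # Joint low-rank compression of keys and values (rank d_latent << d_model).
    c_kv = x @ W_down                                                   # (B, T, d_latent)
    # Up-project the latent back to full-width K and V for all heads.
    k = (c_kv @ W_up_k).reshape(B, T, n_heads, d_head).transpose(1, 2)  # (B, H, T, d_head)
    v = (c_kv @ W_up_v).reshape(B, T, n_heads, d_head).transpose(1, 2)
    q = (x @ Wq).reshape(B, T, n_heads, d_head).transpose(1, 2)
    scores = q @ k.transpose(-2, -1) / d_head ** 0.5
    heads = F.softmax(scores, dim=-1) @ v
    return heads.transpose(1, 2).reshape(B, T, d_model)

B, T, d_model, n_heads, d_latent = 2, 8, 64, 4, 16  # hypothetical toy sizes
x = torch.randn(B, T, d_model)
Wq = torch.randn(d_model, d_model)
W_down = torch.randn(d_model, d_latent)   # joint compression: the low-rank factor
W_up_k = torch.randn(d_latent, d_model)
W_up_v = torch.randn(d_latent, d_model)
print(multi_head_latent_attention(x, Wq, W_down, W_up_k, W_up_v, n_heads).shape)
```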
Overview
- MHA: every head operates independently with its own Q/K/V; the results are concatenated.
- MQA: all heads share one set of keys and values, which is computed only once, shrinking the KV cache and the K/V computation.
- GQA: query heads are split into groups, each sharing one set of keys and values, a middle ground between MHA and MQA.
- MLA: keys and values are jointly compressed into a low-rank latent, cutting memory and compute requirements (see the size comparison below).
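To make the memory comparison concrete, here is a back-of-the-envelope count of cached elements per token per layer; the dimensions are hypothetical, chosen only to show the relative sizes.

```python
# Cached elements per token per layer, with hypothetical dimensions.
n_heads, d_head, n_groups, d_latent = 32, 128, 8, 512

mha_cache = 2 * n_heads * d_head   # full K and V for every head -> 8192
mqa_cache = 2 * d_head             # one shared K/V head         ->  256
gqa_cache = 2 * n_groups * d_head  # one K/V head per group      -> 2048
mla_cache = d_latent               # only the joint latent       ->  512
print(mha_cache, mqa_cache, gqa_cache, mla_cache)
```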
These methods optimize the attention mechanism through different strategies, improving computational efficiency and reducing memory consumption, making Transformer models more efficient in practice.