1. Multi-Head Attention
1.1 Structure
In practice, given the same set of queries, keys, and values, we want the model to learn different behaviors from the same attention mechanism and then combine those behaviors as knowledge, so that it can capture dependencies of various ranges within a sequence (both short-range and long-range dependencies). It can therefore be beneficial to let the attention mechanism jointly use different representation subspaces of the queries, keys, and values. We can think of this in two steps:
- learn with $h$ independent attention heads $h_1,\dots,h_h$
- combine $h_1,\dots,h_h$ (concatenate them and apply a final linear transformation) to produce the pooled output
1.2 Formulas
Given a query $q\in R^{d_q}$, a key $k\in R^{d_k}$, and a value $v\in R^{d_v}$, each attention head $h_i$ $(i=1,\dots,h)$ is computed as:
$$h_i=f(W_i^{(q)}q,\,W_i^{(k)}k,\,W_i^{(v)}v)\in R^{p_v}\tag{1}$$
where the learnable parameters are $W_i^{(q)}\in R^{p_q\times d_q}$, $W_i^{(k)}\in R^{p_k\times d_k}$, and $W_i^{(v)}\in R^{p_v\times d_v}$, and $f$ is the attention pooling function, which can be either additive attention or scaled dot-product attention. The multi-head attention output then passes through one more linear transformation, applied to the concatenation of the $h$ heads, so its learnable parameter is
$W_o\in R^{p_o\times hp_v}$:

$$W_o\begin{bmatrix}h_1\\\vdots\\h_h\end{bmatrix}\in R^{p_o}\tag{2}$$
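To make the dimensions in Eq. (1) and Eq. (2) concrete, here is a minimal sketch with made-up sizes ($d_q=d_k=d_v=8$, $p_q=p_k=p_v=4$, $p_o=8$, $h=2$, and $m=5$ key-value pairs are illustrative assumptions, not the values used in the d2l code below), taking $f$ to be scaled dot-product attention:

# A dimension-level sketch of Eq. (1) and Eq. (2); sizes are illustrative only.
import math
import torch

d_q = d_k = d_v = 8          # input dimensions of query, key, value
p_q = p_k = p_v = 4          # per-head projection dimensions
p_o, h, m = 8, 2, 5          # output dimension, number of heads, number of key-value pairs

q = torch.randn(d_q)         # a single query
K = torch.randn(m, d_k)      # m keys
V = torch.randn(m, d_v)      # m values

heads = []
for i in range(h):
    # randomly initialized per-head projections W_i^{(q)}, W_i^{(k)}, W_i^{(v)}
    W_q = torch.randn(p_q, d_q)
    W_k = torch.randn(p_k, d_k)
    W_v = torch.randn(p_v, d_v)
    # f = scaled dot-product attention over the m key-value pairs
    scores = (K @ W_k.T) @ (W_q @ q) / math.sqrt(p_q)   # (m,)
    alpha = torch.softmax(scores, dim=0)                # attention weights
    heads.append(alpha @ (V @ W_v.T))                   # h_i in R^{p_v}, Eq. (1)

W_o = torch.randn(p_o, h * p_v)
output = W_o @ torch.cat(heads)                         # Eq. (2): a vector in R^{p_o}
print(output.shape)                                     # torch.Size([8])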
1.3 Source code analysis
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: MultiHeadAttention_test
# @Create time: 2022/2/25 9:11
import torch
from torch import nn
from d2l import torch as d2l


class MultiHeadAttention(nn.Module):
    """
    Splits the input X along the feature dimension into num_heads pieces,
    runs attention on all heads in parallel, and merges the results.
    """

    # key_size=100; query_size=100; value_size=100; num_hiddens=100
    # num_heads=5; dropout=0.5
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        # 100 -> 100
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        # 100 -> 100
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        # 100 -> 100
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        # 100 -> 100
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Inputs: queries=(2,4,100); keys=(2,6,100); values=(2,6,100)
        # valid_lens=torch.tensor([3,2])
        # After transpose_qkv:
        #   queries=(2,4,100) -> (2,4,5,20) -> (2,5,4,20) -> (10,4,20)
        #   keys=(2,6,100)    -> (2,6,5,20) -> (2,5,6,20) -> (10,6,20)
        #   values=(2,6,100)  -> (2,6,5,20) -> (2,5,6,20) -> (10,6,20)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # Copy each valid length num_heads times so every head is masked the same way
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        # queries=(10,4,20); keys=(10,6,20); values=(10,6,20)
        # output=(10,4,20)
        output = self.attention(queries, keys, values, valid_lens)
        # (10,4,20) -> (2,5,4,20) -> (2,4,5,20) -> (2,4,100) = output_concat
        output_concat = transpose_output(output, self.num_heads)
        # (2,4,100) -> (2,4,100)
        return self.W_o(output_concat)


def transpose_qkv(X, num_heads):
    # (batch, seq_len, num_hiddens) -> (batch, seq_len, num_heads, num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # -> (batch, num_heads, seq_len, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    # -> (batch*num_heads, seq_len, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    # Reverses transpose_qkv:
    # (batch*num_heads, seq_len, num_hiddens/num_heads) -> (batch, seq_len, num_hiddens)
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)


num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()
print(attention)

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
# x=(2,4,100); y=(2,6,100)
x = torch.ones((batch_size, num_queries, num_hiddens))
y = torch.ones((batch_size, num_kvpairs, num_hiddens))
# attention(x, y, y, valid_lens).shape = (2,4,100)
print(attention(x, y, y, valid_lens).shape)
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
torch.Size([2, 4, 100])
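As a quick sanity check (my own addition, not part of the d2l example), transpose_output undoes transpose_qkv, so concatenating the heads recovers the original (batch, seq_len, num_hiddens) layout:

# Assumes transpose_qkv, transpose_output, and num_heads from the script above are in scope.
z = torch.randn(2, 4, 100)                     # (batch, seq_len, num_hiddens)
z_split = transpose_qkv(z, num_heads)          # (10, 4, 20): one slice per head
z_back = transpose_output(z_split, num_heads)  # (2, 4, 100)
print(z_split.shape, z_back.shape, torch.equal(z, z_back))  # expect True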
1.4 Summary
- To avoid a Python for loop over the heads, we first split queries, keys, and values along the feature dimension into num_heads chunks, run dot-product attention on all chunks as a single batched computation, concatenate the per-head results, and finally apply the output projection. Replacing the loop with one large matrix computation improves efficiency; a sketch verifying the equivalence follows below.
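Here is a minimal sketch of that equivalence check. It assumes the attention module and the tensors x, y, valid_lens from the script in 1.3 are still in scope, and recomputes the output with an explicit Python loop over the heads; the 20-column slices of W_q, W_k, W_v play the role of the per-head projections $W_i^{(q)}, W_i^{(k)}, W_i^{(v)}$.

# An explicit per-head loop, for comparison with the vectorized forward pass above.
# Dropout is inactive because attention.eval() was called, so the check is deterministic.
head_dim = num_hiddens // num_heads                 # 100 // 5 = 20
per_head_attention = d2l.DotProductAttention(dropout=0)

with torch.no_grad():
    q_all = attention.W_q(x)                        # (2, 4, 100)
    k_all = attention.W_k(y)                        # (2, 6, 100)
    v_all = attention.W_v(y)                        # (2, 6, 100)
    head_outputs = []
    for i in range(num_heads):
        sl = slice(i * head_dim, (i + 1) * head_dim)
        # attention for head i on its own 20-dimensional slice
        head_outputs.append(per_head_attention(
            q_all[..., sl], k_all[..., sl], v_all[..., sl], valid_lens))
    loop_out = attention.W_o(torch.cat(head_outputs, dim=-1))   # (2, 4, 100)
    vec_out = attention(x, y, y, valid_lens)

print(torch.allclose(loop_out, vec_out, atol=1e-6))  # expect True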
2. Self-Attention
Given an input sequence of tokens $x_1,\dots,x_n$, where each $x_i\in R^d$ $(1\leq i\leq n)$, the self-attention output is a sequence of the same length, $y_1,\dots,y_n$, where

$$y_i=f(x_i,(x_1,x_1),\dots,(x_n,x_n))\in R^d$$
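A minimal sketch of this definition, taking $f$ to be scaled dot-product attention ($n=3$ tokens of dimension $d=4$ are illustrative sizes, not the ones used in the code below):

# Self-attention: every position attends over the key-value pairs (x_j, x_j).
import math
import torch

n, d = 3, 4
X = torch.randn(n, d)                  # rows are x_1, ..., x_n

scores = X @ X.T / math.sqrt(d)        # each x_i queries every key x_j
alpha = torch.softmax(scores, dim=-1)  # (n, n) attention weights
Y = alpha @ X                          # row i is y_i = f(x_i, (x_1, x_1), ..., (x_n, x_n))
print(Y.shape)                         # torch.Size([3, 4]): same length as the input sequence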
2.1 Source code
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: self_attention_test
# @Create time: 2022/2/27 9:55
import torch
from torch import nn
from d2l import torch as d2l

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
attention.eval()
print(attention)

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
# In self-attention, the queries, keys, and values are all the same tensor x
x = torch.ones((batch_size, num_queries, num_hiddens))
print(f"attention(x, x, x, valid_lens).shape={attention(x, x, x, valid_lens).shape}")
2.2 Summary
Self-attention reuses the multi-head attention mechanism; the only difference is that in self-attention the queries, keys, and values are all the same sequence.