- 看见网上有要手撕注意力的面经,自己开始写,以为很简单,实际上自己菜的要死,遂写本博客。
公式
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
s
o
f
t
m
a
x
(
Q
K
t
d
)
V
Attention(Q,K,V) = softmax(\frac{QK^t}{\sqrt d})V
Attention(Q,K,V)=softmax(dQKt)V
H
e
a
d
i
=
A
t
t
e
n
t
i
o
n
(
Q
W
i
Q
,
K
W
i
K
,
V
W
i
V
)
Head_i = Attention(QW_i^Q,KW_i^K,VW_i^V)
Headi=Attention(QWiQ,KWiK,VWiV)
M
H
A
(
Q
,
K
,
V
)
=
C
o
n
c
a
t
(
H
e
a
d
1
,
…
…
,
H
e
a
d
h
)
W
O
MHA(Q,K,V) = Concat(Head_1,……,Head_h)W^O
MHA(Q,K,V)=Concat(Head1,……,Headh)WO
O为最终输出变幻矩阵
代码
class MultiHeadAttention(nn.Module):
def __init__(self, heads, d_model, dropout=0.1):
super().__init__()
self.d_model = d_model # 模型的维度
self.d_k = d_model // heads # 每个头的维度
self.h = heads # 头的数量
# 以下三个是线性层,用于处理Q(Query),K(Key),V(Value)
self.q_linear = nn.Linear(d_model, d_model)
self.v_linear = nn.Linear(d_model, d_model)
self.k_linear = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout) # Dropout层
self.out = nn.Linear(d_model, d_model) # 输出层
def attention(self, q, k, v, d_k, mask=None, dropout=None):
# torch.matmul是矩阵乘法,用于计算query和key的相似度
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
mask = mask.unsqueeze(1) # 在第一个维度增加维度
scores = scores.masked_fill(mask == 0, -1e9) # 使用mask将不需要关注的位置设置为一个非常小的数
# 对最后一个维度进行softmax运算,得到权重
scores = F.softmax(scores, dim=-1)
if dropout is not None:
scores = dropout(scores) # 应用dropout
output = torch.matmul(scores, v) # 将权重应用到value上
return output
def forward(self, q, k, v, mask=None):
bs = q.size(0) # 获取batch_size
# 将Q, K, V通过线性层处理,然后分割成多个头
k = self.k_linear(k).view(bs, -1, self.h, self.d_k) # batchsize * sentence_length * head * head_dim
q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
# batchsize * head * sentence_length * head_dim
k = k.transpose(1, 2)
q = q.transpose(1, 2)
v = v.transpose(1, 2)
# 调用attention函数计算输出
scores = attention(q, k, v, self.d_k, mask, self.dropout)
# 重新调整张量的形状,并通过最后一个线性层
concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
'''
在PyTorch中,张量的存储方式是连续的(contiguous),也就是说张量的元素在内存中是按照一定顺序存储的。但是,在进行某些操作后,例如通过索引、转置、视图变换等操作,张量可能会变得不再连续。
'''
output = self.out(concat) # 最终输出
return output