The Self-Attention Operation
From the perspective of a single token:
q_i = h_i W_Q, \quad k_j = h_j W_K, \quad v_j = h_j W_V
e_{ij} = q_i k_j^T
\alpha_i = \mathrm{Softmax}([e_{i,1}, \dots, e_{i,T}])
h'_i = \Big( \sum_{j=1}^{T} \alpha_{i,j} v_j \Big) W_O
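To make the single-token view concrete, here is a minimal sketch in plain PyTorch that evaluates the four formulas above for one query position i. The sizes T = 5, d_model = 8, d_head = 4 and the position i = 2 are arbitrary values chosen only for illustration.

import torch

T, d_model, d_head = 5, 8, 4            # arbitrary toy sizes
H = torch.randn(T, d_model)             # hidden states h_1 ... h_T
W_Q = torch.randn(d_model, d_head)
W_K = torch.randn(d_model, d_head)
W_V = torch.randn(d_model, d_head)
W_O = torch.randn(d_head, d_model)

i = 2                                   # one query position
q_i = H[i] @ W_Q                        # q_i = h_i W_Q
K = H @ W_K                             # all k_j = h_j W_K stacked over j
V = H @ W_V                             # all v_j = h_j W_V stacked over j
e_i = q_i @ K.T                         # e_{ij} = q_i k_j^T, shape [T]
alpha_i = torch.softmax(e_i, dim=-1)    # alpha_i = Softmax([e_{i,1}, ..., e_{i,T}])
h_i_new = (alpha_i @ V) @ W_O           # h'_i = (sum_j alpha_{i,j} v_j) W_O
print(h_i_new.shape)                    # torch.Size([8]) == d_model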
In matrix form:
Q = H W_Q, \quad K = H W_K, \quad V = H W_V
E = Q K^T
E' = \mathrm{Softmax}(E)
H' = E' V
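The matrix form is just the single-token computation stacked over all T positions: row i of E holds the scores e_{i,1}, ..., e_{i,T}, and row i of H' is h'_i (before the output projection). A minimal check with the same toy sizes as above:

import torch

T, d_model, d_head = 5, 8, 4
H = torch.randn(T, d_model)
W_Q = torch.randn(d_model, d_head)
W_K = torch.randn(d_model, d_head)
W_V = torch.randn(d_model, d_head)

Q, K, V = H @ W_Q, H @ W_K, H @ W_V     # Q = H W_Q, K = H W_K, V = H W_V
E = Q @ K.T                             # E = Q K^T, shape [T, T]
E_prime = torch.softmax(E, dim=-1)      # row-wise softmax: E' = Softmax(E)
H_prime = E_prime @ V                   # H' = E' V, shape [T, d_head]
print(H_prime.shape)                    # torch.Size([5, 4])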
Single-head self-attention
import math
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, d_model, d_head):
        super(SelfAttention, self).__init__()
        self.w_q = nn.Linear(d_model, d_head)
        self.w_k = nn.Linear(d_model, d_head)
        self.w_v = nn.Linear(d_model, d_head)
        self.w_o = nn.Linear(d_head, d_model)
        self.d_head = d_head

    def forward(self, x):
        # x: [batch_size, max_len, d_model]
        # q, k, v: [batch_size, max_len, d_head]
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        # note: permute is a transpose here, not a reshape
        attn_score = torch.matmul(q, k.permute(0, 2, 1)) / math.sqrt(self.d_head)
        attn_score = torch.softmax(attn_score, dim=-1)  # [batch_size, max_len, max_len]
        output = torch.matmul(attn_score, v)            # [batch_size, max_len, d_head]
        return self.w_o(output)


x = torch.randn(3, 9, 100)
model = SelfAttention(100, 80)
model(x).shape
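The printed shape is torch.Size([3, 9, 100]): w_o projects the 80-dimensional context vectors back to d_model = 100, so the output has the same shape as the input x and the layer can be stacked.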
Multi-head self-attention
# multi-head self-attention
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model=768, d_head=64):
        super(MultiHeadSelfAttention, self).__init__()
        assert d_model % d_head == 0
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.n_heads = d_model // d_head
        self.d_model = d_model
        self.d_head = d_head

    def forward(self, x, mask=None):
        batch_size = x.shape[0]
        max_len = x.shape[1]
        # project, then split the last dimension into heads
        q = self.w_q(x).view(batch_size, max_len, self.n_heads, self.d_head)
        k = self.w_k(x).view(batch_size, max_len, self.n_heads, self.d_head)
        v = self.w_v(x).view(batch_size, max_len, self.n_heads, self.d_head)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)  # [batch_size, n_heads, max_len, d_head]
        # scaled dot-product scores: [batch_size, n_heads, max_len, max_len]
        attn_score = torch.matmul(q, k.permute(0, 1, 3, 2)) / math.sqrt(self.d_head)
        if mask is not None:
            # mask: [batch_size, max_len] -> [batch_size, 1, 1, max_len],
            # so padded *key* positions are hidden from every query in every head
            mask = mask.unsqueeze(1).unsqueeze(2)
            # fill masked positions with a large negative value so softmax sends them to ~0
            attn_score = attn_score.masked_fill(mask == 0, -1e9)
        attn_score = torch.softmax(attn_score, dim=-1)  # [batch_size, n_heads, max_len, max_len]
        # weighted sum, then merge heads back: [batch_size, max_len, d_model]
        out = torch.matmul(attn_score, v).permute(0, 2, 1, 3)
        out = out.contiguous().view(batch_size, max_len, -1)
        return self.w_o(out)


if __name__ == "__main__":
    x = torch.randn(2, 9, 768)
    mask = torch.tensor([
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0],
    ]).bool()
    model = MultiHeadSelfAttention()
    print(model(x, mask).shape)
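For reference, the score / mask / softmax / weighted-sum step inside forward can also be handed to torch.nn.functional.scaled_dot_product_attention (available in PyTorch >= 2.0). The sketch below is only an optional alternative, not part of the module above; sdpa_heads is a hypothetical helper that expects the already-split q, k, v tensors and the [batch_size, max_len] boolean padding mask from the example.

import torch.nn.functional as F

def sdpa_heads(q, k, v, mask=None):
    # q, k, v: [batch_size, n_heads, max_len, d_head], as produced inside forward()
    attn_mask = None
    if mask is not None:
        # broadcast to [batch_size, 1, 1, max_len]; True = attend, False = masked out
        attn_mask = mask[:, None, None, :]
    # the 1/sqrt(d_head) scaling, masking and softmax are handled internally
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

The result keeps the [batch_size, n_heads, max_len, d_head] layout, so the permute / contiguous / view / w_o steps after it stay unchanged.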