# step 7: build scaled dot-product self-attention
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, attn_mask):
    # shape of Q, K, V: (batch_size * num_head, seq_len, model_dim / num_head)
    d_k = Q.size(-1)  # per-head dimension, i.e. model_dim / num_head
    # attention scores, scaled by sqrt(d_k); math.sqrt avoids passing a Python int to torch.sqrt
    score = torch.bmm(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # attn_mask is True at positions that should be masked out
    masked_score = score.masked_fill(attn_mask, float('-inf'))
    prob = F.softmax(masked_score, dim=-1)
    context = torch.bmm(prob, V)
    return context
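
A minimal usage sketch follows; the dimensions and the causal mask below are illustrative assumptions for testing, not values from the original steps:

# illustrative values: batch_size=2, num_head=4, seq_len=5, model_dim=32
batch_size, num_head, seq_len, model_dim = 2, 4, 5, 32
head_dim = model_dim // num_head

Q = torch.rand(batch_size * num_head, seq_len, head_dim)
K = torch.rand(batch_size * num_head, seq_len, head_dim)
V = torch.rand(batch_size * num_head, seq_len, head_dim)

# example causal mask: True above the diagonal marks positions to be masked out
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
attn_mask = attn_mask.unsqueeze(0).expand(batch_size * num_head, -1, -1)

context = scaled_dot_product_attention(Q, K, V, attn_mask)
print(context.shape)  # torch.Size([8, 5, 8])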