This version also borrows from the GitHub implementation; it likewise takes Q, K, V passed in from outside. The code is quite concise, but not that easy to follow at first.
import torch.nn as nn

from .single import Attention


class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0

        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h

        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linear_layers, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        return self.output_linear(x)
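The `Attention` imported from `.single` is the scaled dot-product attention module and is not shown above. The following is a minimal sketch of what that module presumably looks like (softmax(QK^T / sqrt(d_k))V, with optional mask and dropout), plus a quick shape check; the values h=8, d_model=512, batch size 2 and sequence length 10 are arbitrary and only for illustration.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    A minimal sketch of the module assumed by `from .single import Attention`.
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        # query / key / value: (batch, heads, seq_len, d_k)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        if mask is not None:
            # Positions where mask == 0 are blocked from attending.
            scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.matmul(p_attn, value), p_attn


# Quick shape check (hyper-parameters chosen only for illustration):
mha = MultiHeadedAttention(h=8, d_model=512)
x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)
out = mha(x, x, x)            # self-attention: Q = K = V = x
print(out.shape)              # torch.Size([2, 10, 512])

Each of the three linear layers projects the full d_model dimension, which is then reshaped into h heads of size d_k = d_model / h, so the output keeps the same (batch, seq_len, d_model) shape as the input.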