The paper describes multi-head attention as applying several separate linear projections to Q, K, and V — one per head, with 8 heads by default — each projecting 512 -> 64. Scaled dot-product attention is then computed independently inside each head, and the per-head outputs are concatenated to restore the dimension to 512.
Total parameters for the Q/K/V projections: 512 × 64 × 8 (n_head) × 3
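For concreteness, here is a minimal sketch of that per-head formulation (the names are illustrative, not from the paper or the reference code): each of the 8 heads gets its own 512 -> 64 projection for each of Q, K, and V, which is exactly where the count above comes from.

import torch.nn as nn

d_model, d_k, n_head = 512, 64, 8
# One 512 -> 64 projection per head, for each of Q, K, and V
per_head = nn.ModuleList(
    nn.Linear(d_model, d_k, bias=False) for _ in range(3 * n_head)
)
n_params = sum(p.numel() for p in per_head.parameters())
print(n_params)  # 786432 == 512 * 64 * 8 * 3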
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)  # concatenated heads -> d_model
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
        residual = q
        # Pass through the pre-attention projection: b x lq x (n*dk)
        # Separate different heads: b x lq x n x dk
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
        # Transpose for the per-head dot product: b x n x lq x dk
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast the mask over the head axis
        attn = torch.matmul(q / d_k ** 0.5, k.transpose(2, 3))  # b x n x lq x lk
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        attn = F.softmax(attn, dim=-1)
        # Concatenate heads back to b x lq x (n*dv), then project to d_model
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        return self.layer_norm(self.dropout(self.fc(out)) + residual)
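A quick shape check under the paper's default configuration (8 heads, d_model = 512, d_k = d_v = 64); the input tensor here is just illustrative random data:

import torch

mha = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)
x = torch.randn(2, 10, 512)   # batch of 2, sequence length 10
out = mha(x, x, x)            # self-attention: q = k = v
print(out.shape)              # torch.Size([2, 10, 512])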
The code instead applies a single 512 -> 512 linear transform to each of Q, K, and V, and then splits the result into 8 heads of 64 dimensions each. Because 512 × 512 = 512 × 64 × 8, this has exactly the same effect and the same parameter count as performing n_head separate projections. Dot-product attention is again computed inside each head, and the heads are concatenated back to 512; a sketch of the equivalence follows below.
Total parameters for the Q/K/V projections: 512 × 512 × 3 (identical to the count above)
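To see why the two parameterizations match: the fused nn.Linear(512, 512) weight for, say, Q can be read as 8 stacked 64 × 512 blocks, one per head, so slicing its output into 64-dim chunks is the same as applying 8 independent 512 -> 64 projections. A small sketch (variable names are illustrative):

import torch
import torch.nn as nn

fused = nn.Linear(512, 512, bias=False)     # one big Q projection
x = torch.randn(2, 10, 512)
heads_fused = fused(x).view(2, 10, 8, 64)   # split the 512 outputs into 8 heads

# Equivalent: 8 separate 512 -> 64 projections using slices of the same weight
heads_split = torch.stack(
    [x @ fused.weight[i * 64:(i + 1) * 64].T for i in range(8)], dim=2
)
print(torch.allclose(heads_fused, heads_split, atol=1e-6))  # True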