Multi-Head Attention (MHA) is a core component of the Transformer architecture and has achieved notable results in natural language processing, image recognition, and other fields. By splitting the input into multiple "heads", the mechanism lets the model capture different features and patterns of the input in parallel.
Here is a PyTorch snippet implementing MHA:
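Concretely, the standard formulation from the Transformer paper projects the queries, keys, and values h times with different learned projections, runs scaled dot-product attention on each head in parallel, and concatenates the results:

\[
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O,
\qquad
\mathrm{head}_i = \mathrm{Attention}\!\left(Q W_i^Q,\; K W_i^K,\; V W_i^V\right),
\]
\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V .
\]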
import torch
import torch.nn as nn

# Define a multi-head attention class
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, d_k, d_v, n_head, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        # Linear projections mapping d_model into n_head separate Q/K/V subspaces
        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        self.fc = nn.Linear(n_head * d_v, d_model)  # merges the heads back to d_model
        self.attention = ScaledDotProductAttention()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, attn_mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()
        # Project inputs and split into heads: (batch, len, n_head, d) -> (batch, n_head, len, d)
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k).transpose(1, 2)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k).transpose(1, 2)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v).transpose(1, 2)
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1)  # broadcast the mask across heads
        output, attn = self.attention(q, k, v, mask=attn_mask)
        # Concatenate heads: (batch, n_head, len_q, d_v) -> (batch, len_q, n_head * d_v)
        output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        output = self.dropout(self.fc(output))
        return output, attn
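The class above delegates the per-head computation to a ScaledDotProductAttention module that is not included in the snippet. Below is a minimal sketch of such a module, assuming the standard softmax(QK^T / sqrt(d_k))V formulation and a `mask` keyword matching the call in `forward`; the constructor signature and the convention that zero entries of the mask are masked out are assumptions, not part of the original code.

# Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V  (sketch, not from the original snippet)
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        d_k = q.size(-1)
        # Attention scores, scaled by sqrt(d_k) to keep the softmax well-conditioned
        scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
        if mask is not None:
            # Assumed convention: positions where mask == 0 are not attended to
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.dropout(torch.softmax(scores, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn

A quick self-attention usage sketch with hypothetical dimensions (d_model = 512 split across 8 heads of size 64):

mha = MultiHeadAttention(d_model=512, d_k=64, d_v=64, n_head=8)
x = torch.randn(2, 10, 512)   # (batch, sequence length, d_model)
out, attn = mha(x, x, x)      # self-attention: q = k = v
print(out.shape)              # torch.Size([2, 10, 512])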