多头注意力(Multi‑Head Attention)

1. 多头注意力(Multi‑Head Attention)原理

设输入序列表示为矩阵 X ∈ R B × L × d model X\in\mathbb{R}^{B\times L\times d_{\text{model}}} XRB×L×dmodel,其中

  • B B B:批大小(batch size),
  • L L L:序列长度(sequence length),
  • d model d_{\text{model}} dmodel:模型隐层维度(model dimension)。

多头注意力基于对缩放点乘注意力的并行化扩展,引入了 h h h 个注意力头(heads),每个头在不同子空间中学习不同的表示。

1.1 线性映射与切分

我们首先为每个头定义三组可学习权重:
W i Q ∈ R d model × d k , W i K ∈ R d model × d k , W i V ∈ R d model × d v , i = 1 , … , h W_i^Q \in \mathbb{R}^{d_{\text{model}}\times d_k},\quad W_i^K \in \mathbb{R}^{d_{\text{model}}\times d_k},\quad W_i^V \in \mathbb{R}^{d_{\text{model}}\times d_v}, \quad i=1,\dots,h WiQRdmodel×dk,WiKRdmodel×dk,WiVRdmodel×dv,i=1,,h
其中

  • h h h:头数(number of heads),
  • d k d_k dk:每个头中 Query/Key 的维度(key/query dimension),
  • d v d_v dv:每个头中 Value 的维度(value dimension),
  • 通常 d model = h × d k d_{\text{model}}=h\times d_k dmodel=h×dk 且取 d v = d k d_v = d_k dv=dk

对输入 X X X 进行投影,得到第 i i i 个头的查询、键、值:
Q i = X   W i Q , K i = X   W i K , V i = X   W i V Q_i = X\,W_i^Q,\quad K_i = X\,W_i^K,\quad V_i = X\,W_i^V Qi=XWiQ,Ki=XWiK,Vi=XWiV
其中

  • Q i ∈ R B × L × d k Q_i \in \mathbb{R}^{B\times L\times d_k} QiRB×L×dk
  • K i ∈ R B × L × d k K_i \in \mathbb{R}^{B\times L\times d_k} KiRB×L×dk
  • V i ∈ R B × L × d v V_i \in \mathbb{R}^{B\times L\times d_v} ViRB×L×dv

1.2 缩放点乘注意力

对第 i i i 个头,分别对所有位置做点积注意力:

  1. 打分矩阵
    S c o r e i = Q i   K i ⊤ ∈ R B × L × L \mathrm{Score}_i = Q_i\,K_i^\top \quad\in\mathbb{R}^{B\times L\times L} Scorei=QiKiRB×L×L
  2. 缩放
    S c o r e ~ i = S c o r e i d k \widetilde{\mathrm{Score}}_i = \frac{\mathrm{Score}_i}{\sqrt{d_k}} Score i=dk Scorei
  3. Softmax 归一化
    A i = s o f t m a x ( S c o r e ~ i ) ∈ R B × L × L A_i = \mathrm{softmax}\bigl(\widetilde{\mathrm{Score}}_i\bigr) \quad\in\mathbb{R}^{B\times L\times L} Ai=softmax(Score i)RB×L×L
  4. 加权求和
    h e a d i = A i   V i ∈ R B × L × d v \mathrm{head}_i = A_i\,V_i \quad\in\mathbb{R}^{B\times L\times d_v} headi=AiViRB×L×dv

1.3 拼接与线性变换

将所有头的输出在最后一维拼接,再做一次线性投影:
C o n c a t = [ h e a d 1 , … , h e a d h ] ∈ R B × L × ( h   d v ) \mathrm{Concat} = \bigl[\mathrm{head}_1,\dots,\mathrm{head}_h\bigr] \quad\in\mathbb{R}^{B\times L\times (h\,d_v)} Concat=[head1,,headh]RB×L×(hdv)
定义输出权重矩阵
W O ∈ R ( h   d v ) × d model W^O\in\mathbb{R}^{(h\,d_v)\times d_{\text{model}}} WOR(hdv)×dmodel
最终输出:
M u l t i H e a d ( X ) = C o n c a t    W O ∈ R B × L × d model \mathrm{MultiHead}(X) = \mathrm{Concat}\;W^O \quad\in\mathbb{R}^{B\times L\times d_{\text{model}}} MultiHead(X)=ConcatWORB×L×dmodel


2. PyTorch 代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, h: int):
        """
        d_model: 模型维度 d_model
        h: 注意力头数 h
        """
        super().__init__()
        assert d_model % h == 0, "d_model 必须能被 h 整除"
        self.d_model = d_model      # d_model
        self.h = h                  # h
        self.d_k = d_model // h     # d_k = d_model / h
        self.d_v = self.d_k         # d_v 通常等于 d_k

        # 投影矩阵 W_i^Q, W_i^K, W_i^V,实际上合并为一个大矩阵后在 forward 再切分
        self.w_q = nn.Linear(d_model, d_model)  # 同时生成 h 个 Q 投影
        self.w_k = nn.Linear(d_model, d_model)  # 同时生成 h 个 K 投影
        self.w_v = nn.Linear(d_model, d_model)  # 同时生成 h 个 V 投影

        # 输出线性变换 W^O
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, X: torch.Tensor, mask: torch.Tensor = None):
        """
        X: 输入张量,形状 (B, L, d_model)
        mask: 可选掩码,形状 (B, 1, L, L) 或 (B, L, L)
        """
        B, L, _ = X.size()

        # 1. 线性映射到 Q, K, V,然后切分 h 头
        #    先得到 (B, L, h*d_k),再 view/transpose 为 (B, h, L, d_k)
        Q = self.w_q(X).view(B, L, self.h, self.d_k).transpose(1, 2)
        K = self.w_k(X).view(B, L, self.h, self.d_k).transpose(1, 2)
        V = self.w_v(X).view(B, L, self.h, self.d_k).transpose(1, 2)
        # 此时 Q, K, V 形状均为 (B, h, L, d_k)

        # 2. 计算点积注意力
        #    scores = Q @ K^T  -> (B, h, L, L)
        scores = torch.matmul(Q, K.transpose(-2, -1))  
        #    缩放:除以 sqrt(d_k)
        scores = scores / math.sqrt(self.d_k)
        #    可选掩码:将被屏蔽位置设为 -inf 
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        #    Softmax 归一化 -> (B, h, L, L)
        A = F.softmax(scores, dim=-1)

        #    加权求和 -> head_i 形状 (B, h, L, d_k)
        heads = torch.matmul(A, V)

        # 3. 拼接 h 个头:transpose 回 (B, L, h, d_k) 再 reshape
        concat = heads.transpose(1, 2).contiguous().view(B, L, self.h * self.d_k)
        #    concat 形状 (B, L, h*d_k) == (B, L, d_model)

        # 4. 最后一次线性变换 W^O
        output = self.w_o(concat)  # -> (B, L, d_model)
        return output, A  # 返回输出及注意力权重 A

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值