【手撕代码】从零实现Transformer的Scaled Dot-Product Attention 和 Multi-Head Attention(上)

在深度学习和自然语言处理(NLP)领域,Transformer 架构已经成为主流模型的核心组件。其中,Scaled Dot-Product Attention 和 Multi-Head Attention(MHA) 是 Transformer 的核心机制之一。本文将从零开始,手写实现这两个关键模块,并深入讲解其原理和实现细节。


一、Scaled Dot-Product Attention

1.1 原理回顾

Scaled Dot-Product Attention 是注意力机制中最基本的形式,缩放点积注意力,其公式如下:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

其中:

  • Q :Query 向量

  • K :Key 向量

  • V :Value 向量

  • d_k:Key 向量的维度

  • 缩放因子  \sqrt{d_k}是为了防止内积过大导致 softmax 梯度消失

此外,我们还可以通过 mask 来屏蔽无效位置(如 padding 或未来信息),在 softmax 前将无效位置的值设为极小值(如 -1e9)。


1.2 PyTorch 实现

import torch
import torch.nn as nn

# 设置随机种子(固定随机性)
torch.manual_seed(42) 

class ScaledDotProductAttention(nn.Module):
  def __init__(self, d_k, dropout=None):
    super().__init__()
    self.scale = d_k
    self.softmax = nn.Softmax(dim=2)
    # 修改:如果 dropout=None,则使用 nn.Identity() 代替
    self.dropout = nn.Dropout(dropout) if dropout is not None else nn.Identity()


  def forward(self,q,k,v,mask):
    # q: (batch, n_head, seq_len_q, d_k)
    # k: (batch, n_head, seq_len_k, d_k)
    # v: (batch, n_head, seq_len_v, d_v)
    # mask: (batch, n_head, seq_len_q, seq_len_k)
    u = torch.matmul(q,k.transpose(-2,-1))/self.scale
    
    # 应用 mask(如果存在)
    if mask is not None:
        u = u.masked_fill(mask,-1e9)
    
    # softmax 归一化
    att_weight = self.softmax(u)

    # 应用 dropout(如果存在)
    att_weight = self.dropout(att_weight)

    output = torch.matmul(att_weight, v)
    return att_weight,output

1.3 使用示例

batch=1
n_head = 2
n_q, n_k, n_v = 2, 4, 4
d_q, d_k, d_v = 8, 8, 64
q=torch.randn(batch,n_head,n_q,d_q)
k=torch.randn(batch,n_head,n_k,d_k)
v=torch.randn(batch,n_head,n_v,d_v)
mask= torch.zeros(batch,n_head,n_q,n_k).bool()

attention = ScaledDotProductAttention(d_k** 0.5)
attn_weights,output = attention(q,k,v,mask)

print("Output shape:", output.shape)
print("Attention weights shape:", attn_weights.shape)  

二、Multi-Head Attention(MHA)

2.1 原理回顾

Multi-Head Attention 的核心思想是将 Q/K/V 投影到多个不同的表示空间中,分别进行注意力计算,再将结果拼接并通过一个线性变换输出最终结果。

其结构如下:

  1. 通过线性层将输入 Q/K/V 投影到d_{\text{model}}维度

  2. 将每个头的 Q/K/V 拆分成多个头(head)

  3. 对每个头分别调用 Scaled Dot-Product Attention

  4. 拼接所有头的输出,再通过一个线性层输出最终结果


2.2 PyTorch 实现

class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, n_head, d_model):
        super().__init__()
        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_model // n_head

        # 线性投影层
        self.linear_q = nn.Linear(input_dim, d_model)
        self.linear_k = nn.Linear(input_dim, d_model)
        self.linear_v = nn.Linear(input_dim, d_model)

        # 注意力机制
        self.attention = ScaledDotProductAttention(scale=self.d_k**0.5)

        # 输出线性层
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()

        # 线性变换
        q = self.linear_q(x)  # (batch, seq_len, d_model)
        k = self.linear_k(x)
        v = self.linear_v(x)

        # 拆分成多个头
        q = q.view(batch_size, seq_len, self.n_head, self.d_k).permute(0, 2, 1, 3)
        k = k.view(batch_size, seq_len, self.n_head, self.d_k).permute(0, 2, 1, 3)
        v = v.view(batch_size, seq_len, self.n_head, self.d_k).permute(0, 2, 1, 3)

        # 添加 mask 维度
        if mask is not None:
            mask = mask.unsqueeze(1)  # (batch, 1, seq_len_q, seq_len_k)

        # 多头注意力计算
        attn_weights, output = self.attention(q, k, v, mask)

        # 合并多个头
        output = output.permute(0, 2, 1, 3).contiguous()
        output = output.view(batch_size, seq_len, self.d_model)

        # 最终线性层
        output = self.linear_out(output)

        return output, attn_weights

2.3 使用示例

import torch

# 设置参数
batch_size = 2
seq_len = 10
input_dim = 512
n_head = 8
d_model = 512

# 构建输入
x = torch.randn(batch_size, seq_len, input_dim)
mask = torch.zeros(batch_size, seq_len, seq_len).bool()

# 实例化 MHA
mha = MultiHeadAttention(input_dim, n_head, d_model)

# 前向传播
output, attn_weights = mha(x, mask)

print("Output shape:", output.shape)  # (2, 10, 512)
print("Attention weights shape:", attn_weights.shape)  # (2, 8, 10, 10)

三、关键点总结

模块

关键点

Scaled Dot-Product Attention

缩放因子为 \sqrt{d_k} ,mask 使用 masked_fill 填充 -1e9,softmax 而非 sigmoid

Multi-Head Attention

拆头使用 view 和 permute,mask 需要扩展维度,合并头后通过线性层输出

模型结构设计

__init__ 中定义模型参数,forward 中处理前向传播数据

性能优化建议

使用 einops.rearrange 替代 permute 可提升可读性与效率


四、附加挑战建议(进阶)

  1. 支持双向/单向 mask:

    • 编码器中使用 padding mask

    • 解码器中使用 future mask(防止看到未来信息)

  2. 数值稳定性:在 softmax 前做 x - x.max() 操作,防止数值溢出。

  3. 实现GQA和kv cache版本

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值