动手实现Multi-Head Attention

MultiHeadAttention 是 Transformer 模型中的一个核心组件,它允许模型在处理序列的每个位置时同时考虑来自多个“视角”(即头部)的信息。这样做可以提高模型对不同位置关系的理解能力。

重点讲解

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

主要步骤:

  • 线性变换得到QKV,并将QKV分割为多头
  • 计算缩放点积注意力(注意mask可选)
  • 拼接多头
  • 最后再进行一次线性变换

代码实现

下面,我将使用 PyTorch 框架实现一个基本的 MultiHeadAttention 模块。

import torch
import torch.nn as nn
import torch.nn.functional as F

import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads
        
        # 定义线性层和输出线性层
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.final_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """分割最后一个维度到 (num_heads, depth).
        转置结果使得形状为 (batch_size, num_heads, seq_length, depth)
        """
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # 1. 线性层和分割到多头
        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)
        
        # 2. 缩放点积注意力
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.depth)
        if mask is not None:
            scores = scores.masked_fill(mask == True, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        
        # 3. 将注意力权重应用到值上
        output = torch.matmul(attention_weights, value)
        
        # 4. 连接头部
        output = output.permute(0, 2, 1, 3).contiguous()
        output = output.view(batch_size, -1, self.d_model)
        
        # 5. 最后一次线性变换
        output = self.final_linear(output)
        
        return output

流程图(维度变换示意图)

在这里插入图片描述

self-attention示例

   d_model = 512  # 模型维度
   num_heads = 8  # 头数
   mha = MultiHeadAttention(d_model, num_heads)

   # 创建随机数据
   batch_size = 4
   seq_length = 60
   x = torch.rand(batch_size, seq_length, d_model)  # 输入假设维度为 (batch_size, seq_length, d_model)

   output = mha(x, x, x)  # 自注意力机制,qkv的输入相同;而cross-attention中,query来自decoder,kv来自encoder
   print(output.shape)  

加入mask示例

解码器的自注意力层需要确保当前位置只能注意到前面的位置(包括当前位置),而不是未来的位置。这通常通过一个未来位置掩码实现,它是一个下三角矩阵

import torch

def generate_square_subsequent_mask(seq_len):
    """生成一个未来步骤掩码,用于解码器中防止看到未来信息。"""
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()   # diagonal 控制对角线开始的位置
    return mask

d_model = 512  # 模型维度
num_heads = 8  # 头数
mha = MultiHeadAttention(d_model, num_heads)

# 创建随机数据
batch_size = 4
seq_length = 60
x = torch.rand(batch_size, seq_length, d_model)  # 输入假设维度为 (batch_size, seq_length, d_model)

# 生成掩码并将其应用于解码器的自注意力层
future_mask = generate_square_subsequent_mask(seq_length).to(x.device)
output = mha(x, x, x, mask=future_mask)  # 自注意力机制
print(output.shape)  # 应为 (batch_size, seq_length, d_model)

注意广播机制
在 PyTorch 中,masked_fill 函数可以很灵活地处理维度差异情况,通过广播(broadcasting)机制来匹配维度。

在这里插入图片描述

  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值