手撕multi-head self attention 代码

在深度学习和自然语言处理领域,多头自注意力(Multi-Head Self-Attention)机制是Transformer模型中的核心组件之一。它允许模型在处理序列数据时,能够同时关注序列中的不同位置,从而捕获到丰富的上下文信息。下面,我们将详细解析多头自注意力机制的实现代码。

一、概述

多头自注意力机制的核心思想是将输入序列进行多次线性变换,然后分别计算自注意力得分,最后将所有头的输出进行拼接,并通过一个线性层得到最终的输出。这样做的好处是可以让模型从不同的子空间学习到不同的注意力信息,提高模型的表达能力。

二、代码实现

以下是一个简化版的多头自注意力机制的PyTorch实现,如果有不足之处,感谢指出!!!!:

import torch
import torch.nn as nn
import math

class MultiHeadSelfAttention(nn.Module):
    """
    多头注意力模块,用于实现transformer模型中的注意力机制。
    
    参数:
        model_dim: 模型维度,即输入和输出的向量维度。
        num_heads: 注意力头的数量。
        dropout_rate: Dropout率,防止模型过拟合,默认为0.1。
    """
    def __init__(self, model_dim, num_heads, dropout_rate=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads
        assert model_dim % num_heads == 0, "model_dim 必须能整除注意力头的数量。"
        self.query_projection = nn.Linear(model_dim, model_dim)
        self.key_projection = nn.Linear(model_dim, model_dim)
        self.value_projection = nn.Linear(model_dim, model_dim)
        self.output = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, inputs, attention_mask=None, target=None):
        """
        前向传播函数。

        参数:
        - inputs: 输入张量,形状为(batch_size, sequence_length, model_dim)。
        - mask: 掩码张量,形状为(batch_size, sequence_length, sequence_length)。

        返回:
        - output: 输出张量,形状为(batch_size, sequence_length, model_dim)。
        """
        
        batch_size, sequence_length, _ = inputs.shape

        # 对Query、Key和Value进行线性变换
        querys = self.query_projection(inputs)
        keys = self.key_projection(inputs)
        values = self.value_projection(inputs)

        # 进行矩阵分割以实现多头注意力
        querys = querys.reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.reshape(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算scaled dot-product attention,考虑注意力掩码
        attention_scores = torch.matmul(querys, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).expand(-1, self.num_heads, sequence_length, -1)
            attention_scores = attention_scores.masked_fill(attention_mask == 0, float('-inf'))
        attention_probs = self.softmax(attention_scores)
        #应用训练阶段的dropout
        if target is not None:
            attention_probs = self.dropout(attention_probs) 
        attention_weights = torch.matmul(attention_probs, values).transpose(1, 2).reshape(batch_size, sequence_length, self.model_dim)
        output = self.output(attention_weights)
        return output, attention_probs
           
# 使用示例:
model_dim = 512
num_heads = 8
mask_attention = torch.IntTensor([[ 1 if i < 8 else 0 for i in range(10) ]])
attention_layer = MultiHeadSelfAttention(model_dim, num_heads)
inputs = torch.randn(1, 10, model_dim)  # 假设我们有一个批次大小为1,序列长度为10,模型维度为512的输入
outputs, attention_weight= attention_layer(inputs, mask_attention)
print(outputs)
print(attention_weight)

  • 19
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

心若成风、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值