【HuggingFace Transformers】BertSdpaSelfAttention源码解析

1. BertSdpaSelfAttention类 介绍

BertSdpaSelfAttention类是 BERT 模型自注意力层的实现,继承 BertSelfAttention 类。BertSdpaSelfAttention 模块是在 Hugging Facetransformers 库的 4.31.0 版本后引入的。这一模块是为了增强 BERT 模型中自注意力机制的效率和性能,利用了 PyTorch 中的 scaled_dot_product_attention 函数。

2. BertSdpaSelfAttention类 源码解析

源码地址:transformers/src/transformers/models/bert/modeling_bert.py

# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:30

import torch

from typing import Optional, Tuple
from packaging import version
from transformers.models.bert.modeling_bert import BertSelfAttention
from transformers.utils import get_torch_version, logging

logger = logging.get_logger(__name__)


class BertSdpaSelfAttention(BertSelfAttention):
    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type=position_embedding_type)  # 初始化父类,继承其配置和属性
        self.dropout_prob = config.attention_probs_dropout_prob  # 存储注意力机制中的dropout概率
        # 检查当前的PyTorch版本,决定是否需要连续的qkv输入,低于2.2.0版本时设为True
        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

    # Adapted from BertSelfAttention
    def forward(
            self,
            hidden_states: torch.Tensor,  # 输入的隐藏状态张量
            attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码,可选
            head_mask: Optional[torch.FloatTensor] = None,  # 头部掩码,可选
            encoder_hidden_states: Optional[torch.FloatTensor] = None,  # 编码器的隐藏状态,仅在交叉注意力中使用
            encoder_attention_mask: Optional[torch.FloatTensor] = None,  # 编码器的注意力掩码,仅在交叉注意力中使用
            past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,  # 过去的键和值,用于缓存
            output_attentions: Optional[bool] = False,  # 是否输出注意力权重
    ) -> Tuple[torch.Tensor]:

        # 1. 如果位置嵌入类型不是 "absolute" 或者需要输出注意力权重或使用头部掩码,则调用BertSelfAttention的forward方法来处理
        if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
            # 记录一个警告,提示用户在未来版本中需要手动指定注意力实现
            logger.warning_once(
                "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
                "the manual attention implementation, but specifying the manual implementation will be required from "
                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
            )

        # -------------------- 2. 位置嵌入类型position_embedding_type是 "absolute" -------------------------
        # ----------------- 2.1 获取输入的批次大小(bsz)、目标序列长度(tgt_len), 后面会用于shape的调整-----------------
        bsz, tgt_len, _ = hidden_states.size()

        # ---------------- 2.2 获取query_layer, key_layer, value_layer, attention_mask, is_causal, 用于注意力的计算-------
        # 将hidden_states投影到查询向量(query_layer),并使用transpose_for_scores方法调整其形状
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
        # mask needs to be such that the encoder's padding tokens are not attended to.
        """ 如果这是一个跨注意力模块实例化的情况,键和值来自编码器;注意力掩码需要确保编码器中的填充标记不被关注。 """
        # 判断是否是交叉注意力,并为current_states和attention_mask赋值
        is_cross_attention = encoder_hidden_states is not None
        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

        # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
        # 如果是交叉注意力且有past_key_value,并且它的序列长度与current_states一致,直接使用缓存的键和值
        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
            key_layer, value_layer = past_key_value
        else:
            # 否则,计算新的键值对,并在非交叉注意力情况下,将它们与过去的键值对拼接
            key_layer = self.transpose_for_scores(self.key(current_states))
            value_layer = self.transpose_for_scores(self.value(current_states))
            if past_key_value is not None and not is_cross_attention:
                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        # 如果是解码器,则缓存当前的键和值,以便在后续步骤中使用
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            """
            如果是交叉注意力,将所有交叉注意力的 key/value 状态保存为一个包含两个 torch.Tensor 的元组。
            后续对交叉注意力层的调用可以重用所有交叉注意力的 key/value 状态(即第一个 "if" 情况)。
            如果是单向自注意力(解码器),则保存所有先前解码器的 key/value 状态为一个包含两个 torch.Tensor 的元组。
            后续对单向自注意力层的调用可以将先前解码器的 key/value 状态与当前投影的 key/value 状态拼接起来(即第三个 "elif" 情况)。
            如果是编码器的双向自注意力,则 `past_key_value` 始终为 `None`。
            """
            past_key_value = (key_layer, value_layer)

        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        """
        在 torch==2.1.2 中,当使用非连续的输入和自定义的注意力掩码时,带有内存高效后端的 SDPA(缩放点积注意力)是有问题的,
        因此我们需要在这里调用 `.contiguous()` 方法来确保输入是连续的。这个问题在 torch==2.2.0 中已被修复。'
        参考:https://github.com/pytorch/pytorch/issues/112577
        """
        # 如果 PyTorch 版本低于 2.2.0 且设备类型为 CUDA 且 attention_mask 不为空,则确保 qkv 输入是连续的
        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
            query_layer = query_layer.contiguous()
            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()

        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal
        # mask in case tgt_len == 1.
        """
        需要保证 tgt_len > 1,以匹配 AttentionMaskConverter.to_causal_4d 的行为,
        因为当 tgt_len == 1 时,它不会创建因果掩码。
        """
        # 如果是解码器且没有注意力掩码且目标序列长度大于1,则启用因果注意力
        is_causal = self.is_decoder and attention_mask is None and tgt_len > 1

        # ----------- 2.3 使用 torch.nn.functional.scaled_dot_product_attention 计算注意力 -------------
        # 使用 PyTorch 的 scaled_dot_product_attention 函数计算注意力输出,传入查询、键、值、注意力掩码和dropout概率
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=is_causal,
        )
        # ----------- 2.4 调整 attn_output 的形状, 用于返回计算后的注意力输出, 以及在解码器模式下缓存的键值 -----------------
        # 将注意力输出张量转置并调整形状以匹配原始输入的形状
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)

        # 返回包含注意力输出的元组,如果是解码器,还会返回缓存的键和值
        outputs = (attn_output,)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CS_木成河

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值