【HuggingFace Transformers】BertAttention源码解析

1. BertAttention类 介绍

BERT 模型中,BertAttention 是每一个 Transformer 编码器层的核心部分。BERT 模型通常由多个编码器层叠加而成,每个编码器层都包含一个 BertAttention 层,用于捕获输入序列中各个位置之间的依赖关系。

1.1 BertAttention 的组成

BertAttention 类由两个主要部分组成:

(1) 自注意力层 (BertSelfAttention)

该层执行注意力的计算。它基于输入的隐层状态计算注意力权重,并生成上下文向量。
BertSelfAttention 采用了多头注意力机制,它将输入拆分为多个头(每个头独立计算注意力),然后将结果连接起来,以增加模型的表达能力。在计算注意力时,它会考虑输入序列中的所有位置,这样可以捕获远距离的依赖关系。

(2) 输出层 (BertSelfOutput)

该层对 BertSelfAttention 生成的上下文向量进行进一步的线性变换和归一化处理。它包含一个全连接层、dropout层以及层归一化(layer normalization),用于稳定和优化模型的训练。

1.2 BertAttention 的工作流程

  • 计算自注意力 (self.self)
    在前向传播过程中,首先通过 BertSelfAttention 层对输入隐层状态计算注意力。这个过程涉及对查询、键和值进行线性变换,然后计算点积注意力(dot-product attention)。
  • 处理注意力输出 (self.output)
    自注意力计算的结果会传递给 BertSelfOutput 层,在这里进行线性变换、添加残差连接(residual connection),再进行层归一化。这一步的输出将作为模型的最终注意力输出传递给后续的 Transformer 层。
  • 剪枝(可选
    BertAttention 支持注意力头的剪枝(pruning),这意味着可以在推理或训练过程中动态减少注意力头的数量,以减少计算量和内存占用。

2. BertAttention类 源码解析

源码地址:transformers/src/transformers/models/bert/modeling_bert.py

# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:15

import torch

from torch import nn
from typing import Optional, Tuple
from transformers.models.bert.modeling_bert import BertSelfOutput, BertSelfAttention, BertSdpaSelfAttention
from transformers.pytorch_utils import prune_linear_layer, find_pruneable_heads_and_indices


# 定义一个字典 BERT_SELF_ATTENTION_CLASSES,它将注意力机制的实现方式("eager" 或 "sdpa")映射到相应的类。这允许根据配置选择不同的自注意力实现。
BERT_SELF_ATTENTION_CLASSES = {
    "eager": BertSelfAttention,
    "sdpa": BertSdpaSelfAttention,
}


class BertAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()

        # 根据配置对象 config 中指定的注意力实现类型,从 BERT_SELF_ATTENTION_CLASSES 字典中获取相应的自注意力类
        # 并使用 config 和 position_embedding_type 实例化自注意力层,将其赋值给 self.self
        self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config, position_embedding_type=position_embedding_type
        )
        # 实例化 BertSelfOutput 类,用于处理自注意力层的输出
        self.output = BertSelfOutput(config)
        # 初始化一个空集合 self.pruned_heads,用于存储被剪枝(移除)的注意力头索引
        self.pruned_heads = set()

    # 用于修剪(剪枝)指定的注意力头
    def prune_heads(self, heads):
        # 如果传入的 heads 列表为空,则直接返回,不进行任何操作
        if len(heads) == 0:
            return
        # 通过 find_pruneable_heads_and_indices 函数找到需要被剪枝的注意力头及其对应的索引
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        # 使用 prune_linear_layer 对 query, key, value 线性层以及 output.dense 进行剪枝
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        # 更新超参数,减少注意力头的数量和尺寸
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads

        # 将被剪枝的注意力头存储到 pruned_heads 集合中
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # 调用 self-attention 机制计算自注意力输出
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # 获取自注意力机制输出的第一个元素并传递给输出层,生成最终的注意力输出
        attention_output = self.output(self_outputs[0], hidden_states)
        # 将注意力输出与其他输出(如果有)结合在一起
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CS_木成河

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值