【HuggingFace Transformers】BertLayer源码解析

1. BertLayer类 介绍

BertLayer 类是 BERT 模型的基本构建块之一。它实现了 Transformer 架构中的一个层级结构,在输入的序列上执行自注意力、前馈神经网络等操作,并在必要时结合交叉注意力机制(在解码器中)。其核心功能为:

  • 自注意力机制:在 BERT 模型中,自注意力是核心组件,用于计算序列中各个位置之间的相关性。
  • 交叉注意力机制:在编码器-解码器架构中(如在 is_decoder=True 的情况下),交叉注意力用于处理解码器的隐藏状态与编码器的隐藏状态之间的相关性。
  • 前馈神经网络:经过自注意力和交叉注意力处理后的隐藏状态,进一步通过前馈神经网络进行非线性变换。
  • 缓存机制:在解码器中使用缓存机制,以加速序列生成任务中的推理过程。

2. BertLayer类 源码注释

源码地址:transformers/src/transformers/models/bert/modeling_bert.py

以下代码为 BertLayer 的详细计算过程,包括自注意力交叉注意力(如果存在)、以及前馈神经网络的处理。

最后的输出 outputs 包括以下内容(按顺序):

  • layer_output:前馈神经网络处理后的最终隐藏状态(必有项)。
  • 自注意力权重(如果 output_attentions=True)。
  • 交叉注意力权重(如果 output_attentions=True 且当前层是解码器并启用了交叉注意力)。
  • present_key_value(如果当前层是解码器,则最后包含这个缓存)。
# -*- coding: utf-8 -*-
# @time: 2024/7/12 18:27

import torch

from typing import Optional, Tuple
from torch import nn
from transformers import apply_chunking_to_forward
from transformers.models.bert.modeling_bert import BertAttention, BertIntermediate, BertOutput


class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward  # 从配置中获取前向传播时的块大小,用于控制前馈神经网络的分块处理(主要用于节省内存)
        self.seq_len_dim = 1  # 设置序列长度维度为 1(通常是 batch_size 维度后面的一维)
        self.attention = BertAttention(config)  # 实例化 BertAttention,用于处理输入的隐藏状态
        self.is_decoder = config.is_decoder  # 用于判断当前层是否是解码器的一部分
        self.add_cross_attention = config.add_cross_attention  # 用于判断是否需要添加交叉注意力层(通常在解码器中用于处理编码器的输出)

        # 如果启用了交叉注意力,并且当前层是解码器,则实例化 BertAttention 作为交叉注意力层
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = BertAttention(config, position_embedding_type="absolute")

        # 实例化 BertIntermediate 和 BertOutput,分别用于前馈神经网络的中间层和输出层
        self.intermediate = BertIntermediate(config)  # 初始化中间层
        self.output = BertOutput(config)  # 初始化输出层

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:

        # --------------1. 自注意力机制--------------------------------------------------------------- #
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        # 解码器单向自注意力缓存的键/值对位于位置 1 和 2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        # 从自注意力输出中提取注意力输出张量
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        # 如果是解码器,最后的输出是自注意力缓存的键/值对
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            # 如果输出注意力权重,则添加自注意力的输出
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
        # ------------至此,得到了attention_output, outputs, present_key_value(可选)------------------------ #

        # --------------2. 交叉注意力机制(仅在解码器中使用)------------------------------------------------- #
        # 交叉注意力的缓存键/值对,初始设为 None
        cross_attn_present_key_value = None

        # 如果是解码器且提供了编码器隐藏状态,则执行交叉注意力计算
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            # 如果存在 past_key_value,则提取最后两个元素作为交叉注意力的缓存键/值对
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            # 调用 crossattention 层进行交叉注意力计算
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            # 更新注意力输出为交叉注意力的输出
            attention_output = cross_attention_outputs[0]
            # 如果输出注意力权重,则将交叉注意力的中间输出添加到 outputs
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            # 添加交叉注意力缓存到 present_key_value 元组的第 3 和第 4 位置
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value  # 交叉注意力的键值对缓存也会被提取并与自注意力的缓存结合在一起
        # ------------至此,得到了更新后的attention_output, outputs, present_key_value--------------------- #

        # ------------------3. 前馈神经网络------------------ #
        # 使用 apply_chunking_to_forward 函数将前馈网络的计算拆分为多个块进行处理
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk,  # 前馈网络的计算函数
            self.chunk_size_feed_forward,  # 前馈网络计算的块大小
            self.seq_len_dim,  # 序列长度维度
            attention_output  # 自注意力层的输出
        )

        # 将前馈网络的输出添加到 outputs 元组中
        outputs = (layer_output,) + outputs

        # ------------------4. 返回结果------------------ #
        # if decoder, return the attn key/values as the last output
        # 如果当前层是解码器,则返回包含 present_key_value(注意力键值对缓存)的 outputs
        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        # 否则,返回更新后的隐藏状态和(可选的)注意力权重
        return outputs

    """outputs 包括以下内容(按顺序):
        - layer_output:前馈神经网络处理后的最终隐藏状态(必有项)。
        - 自注意力权重(如果 output_attentions=True)。
        - 交叉注意力权重(如果 output_attentions=True 且当前层是解码器并启用了交叉注意力)。
        - present_key_value(如果当前层是解码器,则最后包含这个缓存)。
    """

    # 前馈神经网络
    def feed_forward_chunk(self, attention_output):
        # 通过前馈网络的中间层处理自注意力输出
        intermediate_output = self.intermediate(attention_output)
        # 将中间层输出和注意力输出传递给前馈网络的输出层
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CS_木成河

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值