【HuggingFace Transformers】BertEncoder源码解析

1. BertEncoder类 介绍

BertEncoder 类用于实现 BERT 模型的编码器部分。这个编码器由多个堆叠的 BertLayer 组成,用于处理输入的隐藏状态并生成新的隐藏状态。核心功能为:

  • BertLayer 堆叠:编码器的核心是 BertLayer 的堆叠,每一层都处理输入的隐藏状态并生成新的隐藏状态。
  • 梯度检查点:这一功能主要用于减少内存消耗,但需要注意它与缓存功能不兼容。
  • 多种输出选择:可以选择是否输出所有隐藏状态、注意力权重,以及以何种形式返回结果(字典或元组)。
  • 兼容性检查:代码中对梯度检查点与缓存的兼容性进行了检查,并在发现冲突时自动调整。

2. BertEncoder类 源码注释

源码地址:transformers/src/transformers/models/bert/modeling_bert.py

# -*- coding: utf-8 -*-
# @time: 2024/7/12 17:20

import torch

from torch import nn
from transformers import BertLayer
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.utils import logging
from typing import Optional, Tuple, Union

logger = logging.get_logger(__name__)


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        # 保存配置对象,用于配置编码器的各种参数
        self.config = config

        # 使用 nn.ModuleList 创建一个包含多个 BertLayer 的列表
        # BertLayer 是 BERT 模型的基本组成部分,每个 BertLayer 使用相同的配置
        # config.num_hidden_layers 指定了要创建的 BertLayer 层数,bert-base版是12
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

        # 初始化梯度检查点标志为 False(默认)
        # 梯度检查点(Gradient Checkpointing)是一种节省内存的技术,用于在计算过程中保存中间状态,以减少内存使用
        # 当需要节省内存时,可以通过设置这个标志来启用梯度检查点功能
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """这个方法执行前向传播操作,计算从输入的隐藏状态到输出隐藏状态的映射过程。"""

        """ 1. 初始化输出变量 """
        # 如果需要输出所有隐藏状态,初始化一个空元组来保存隐藏状态,否则为 None
        all_hidden_states = () if output_hidden_states else None
        # 如果需要输出注意力权重,初始化一个空元组来保存自注意力的权重,否则为 None
        all_self_attentions = () if output_attentions else None
        # 如果需要输出注意力权重且配置中添加了交叉注意力,初始化一个空元组来保存交叉注意力的权重,否则为 None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        """ 2. 梯度检查点 """
        # 检查是否启用了梯度检查点(gradient_checkpointing),如果启用了且 use_cache 为 True,则输出警告信息并将 use_cache 设置为 False。这是因为梯度检查点功能与缓存不兼容。
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                use_cache = False

        # 初始化 `next_decoder_cache`,如果 `use_cache` 为 True,则为一个空元组,否则为 None
        next_decoder_cache = () if use_cache else None

        """ 3. BertLayer 的堆叠计算 """
        # -------------------------------核心部分:多个BertLayer的堆叠-----------------------------------
        # 遍历所有的 BertLayer,对每一层执行前向传播操作
        for i, layer_module in enumerate(self.layer):
            # 如果需要输出所有隐藏状态,则将当前的隐藏状态添加到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 获取当前层的头掩码,如果 head_mask 为 None,则 layer_head_mask 也为 None
            layer_head_mask = head_mask[i] if head_mask is not None else None
            # 获取当前层的过去的键值对,如果 past_key_values 为 None,则 past_key_value 也为 None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            # 如果启用了梯度检查点并且当前处于训练模式
            if self.gradient_checkpointing and self.training:
                # 使用自定义的梯度检查点函数执行当前层的前向传播
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                # 否则,正常执行当前层的前向传播
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            # 在每一层计算完成后,更新 hidden_states 为当前层的输出。
            hidden_states = layer_outputs[0]

            # 如果使用缓存,将当前层的缓存结果添加到 next_decoder_cache 中
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)

            # 如果需要输出注意力权重,将当前层的自注意力权重添加到 all_self_attentions 中
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # 如果配置中启用了交叉注意力,将当前层的交叉注意力权重添加到 all_cross_attentions 中
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # -------------------------------------BertLayer堆叠完成---------------------------------------

        """ 4. 处理输出 """
        # 如果需要输出所有隐藏状态,将最终的 hidden_states 添加到 all_hidden_states 中。
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 如果 return_dict 为 False,则返回一个包含多个值的元组,这些值中只包含非 None 的值
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,  # 隐藏状态
                    next_decoder_cache,  # 下一步解码器缓存
                    all_hidden_states,  # 所有隐藏状态
                    all_self_attentions,  # 所有自注意力
                    all_cross_attentions,  # 所有交叉注意力
                ]
                if v is not None
            )

        # 如果 return_dict 为 True,则返回一个包含详细信息的模型输出对象
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,  # 最后的隐藏状态
            past_key_values=next_decoder_cache,  # 过去的键值对
            hidden_states=all_hidden_states,  # 所有隐藏状态
            attentions=all_self_attentions,  # 自注意力
            cross_attentions=all_cross_attentions,  # 交叉注意力
        )
  • 5
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CS_木成河

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值