【HuggingFace Transformers】BertIntermediate 和 BertPooler源码解析

1. 介绍

1.1 位置与功能

(1) BertIntermediate

  • 位置:位于 BertLayer 的注意力层(BertSelfAttention)和输出层(BertOutput)之间。
  • 功能:它执行一个线性变换(通过全连接层)并跟随一个激活函数(通常是 ReLU),为后续层提供更高层次的特征表示。

(2) BertPooler

  • 位置:位于整个 BertModel 的最后一层之后,直接处理经过编码的序列表示。
  • 功能:从序列的第一个标记(即 [CLS] 标记)提取特征,并通过一个线性变换和 Tanh 激活函数来生成一个全局表示,通常用于分类任务中的最终输出。

1.2 相似点与不同点

(1) 相似点

  • 两者都涉及到线性变换,并且都通过激活函数来增强模型的表达能力。
  • 都是 BERT 模型中的重要组成部分,从不同的角度和层次上处理输入数据。

(2) 不同点

  • 应用层次:
    BertIntermediate 作用于每个 Transformer 层,用于构建更深的层级特征。
    BertPooler 只在模型的最后一层作用,用于提取全局特征。
  • 功能目标:
    BertIntermediate 增强中间层的非线性特征,助于后续的自注意力机制。
    BertPooler 为分类或回归任务提供一个紧凑的全局特征表示。

2. 源码解析

源码地址:transformers/src/transformers/models/bert/modeling_bert.py

2.1 BertIntermediate 源码解析

# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:17
import torch

from torch import nn
from transformers.activations import ACT2FN


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 全连接层,将 hidden_size 映射到 intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

        # 根据 config.hidden_act 定义激活函数
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)  # 线性变换
        hidden_states = self.intermediate_act_fn(hidden_states)  # 激活函数
        return hidden_states

2.2 BertPooler 源码解析

# -*- coding: utf-8 -*-
# @time: 2024/7/19 11:41

import torch

from torch import nn


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # 全连接层,将 hidden_size 映射回 hidden_size
        self.activation = nn.Tanh()  # 激活函数为 Tanh 函数

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        # 提取序列中的第一个 token,也就是 [CLS] 的 hidden state
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)  # 线性变换
        pooled_output = self.activation(pooled_output)  # 激活函数
        return pooled_output
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CS_木成河

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值