BertIntermediate 和 BertPooler源码解析
1. 介绍
1.1 位置与功能
(1) BertIntermediate
- 位置:位于 BertLayer 的注意力层(BertSelfAttention)和输出层(BertOutput)之间。
- 功能:它执行一个线性变换(通过全连接层)并跟随一个激活函数(通常是 ReLU),为后续层提供更高层次的特征表示。
(2) BertPooler
- 位置:位于整个 BertModel 的最后一层之后,直接处理经过编码的序列表示。
- 功能:从序列的第一个标记(即 [CLS] 标记)提取特征,并通过一个线性变换和 Tanh 激活函数来生成一个全局表示,通常用于分类任务中的最终输出。
1.2 相似点与不同点
(1) 相似点
- 两者都涉及到线性变换,并且都通过激活函数来增强模型的表达能力。
- 都是 BERT 模型中的重要组成部分,从不同的角度和层次上处理输入数据。
(2) 不同点
- 应用层次:
BertIntermediate 作用于每个 Transformer 层,用于构建更深的层级特征。
BertPooler 只在模型的最后一层作用,用于提取全局特征。 - 功能目标:
BertIntermediate 增强中间层的非线性特征,助于后续的自注意力机制。
BertPooler 为分类或回归任务提供一个紧凑的全局特征表示。
2. 源码解析
源码地址:transformers/src/transformers/models/bert/modeling_bert.py
2.1 BertIntermediate 源码解析
# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:17
import torch
from torch import nn
from transformers.activations import ACT2FN
class BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
# 全连接层,将 hidden_size 映射到 intermediate_size
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
# 根据 config.hidden_act 定义激活函数
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states) # 线性变换
hidden_states = self.intermediate_act_fn(hidden_states) # 激活函数
return hidden_states
2.2 BertPooler 源码解析
# -*- coding: utf-8 -*-
# @time: 2024/7/19 11:41
import torch
from torch import nn
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size) # 全连接层,将 hidden_size 映射回 hidden_size
self.activation = nn.Tanh() # 激活函数为 Tanh 函数
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
# 提取序列中的第一个 token,也就是 [CLS] 的 hidden state
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor) # 线性变换
pooled_output = self.activation(pooled_output) # 激活函数
return pooled_output