1. 类定义与作用
• 功能:BERT模型的前馈神经网络中间层(Feed-Forward Network, FFN),属于Transformer层的核心组件之一。
• 位置:位于自注意力层(BertSelfAttention
)之后,负责对注意力输出进行非线性变换和维度扩展。
• 典型配置(以BERT-base为例):
• 输入维度:hidden_size = 768
• 中间维度:intermediate_size = 3072
(扩展4倍)
• 激活函数:hidden_act = "gelu"
(高斯误差线性单元)
2. 初始化方法 (__init__
)
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) # 线性变换层
# 激活函数处理逻辑
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.intermediate_act_fn = AAA[config.hidden_act] # 从字符串映射到函数
else:
self.intermediate_act_fn = config.hidden_act # 直接使用函数
关键组件
-
线性变换层 (
self.dense
):
• 输入维度:config.hidden_size
(如768)
• 输出维度:config.intermediate_size
(如3072)
• 作用:将自注意力输出的隐藏状态从hidden_size
投影到更大的中间维度。 -
激活函数选择逻辑:
• 输入为字符串(如"gelu"
,"relu"
):通过预定义的字典AAA
将字符串映射到对应的PyTorch激活函数。
• 输入为函数(如torch.nn.functional.gelu
):直接使用该函数。
• Python 2兼容性:额外检查unicode
类型(Python 2中字符串的Unicode表示)。
示例激活函数映射 (AAA
):
AAA = {
"gelu": torch.nn.functional.gelu,
"relu": torch.nn.functional.relu,
"tanh": torch.tanh,
# 其他自定义激活函数...
}
3. 前向传播 (forward
方法)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states) # [batch, seq_len, intermediate_size]
hidden_states = self.intermediate_act_fn(hidden_states) # 应用激活函数
return hidden_states
步骤解析
-
线性变换:
• 输入形状:[batch_size, sequence_length, hidden_size]
• 输出形状:[batch_size, sequence_length, intermediate_size]
• 示例:[2, 128, 768] → [2, 128, 3072]
-
激活函数:
• 对每个位置的特征向量独立应用非线性激活(如GELU、ReLU)。
• GELU的特点:通过门控机制保留部分原始信息,比ReLU更平滑,适合自然语言任务。
4. 与后续层的衔接
• 下游处理:BertIntermediate
的输出会传递给BertOutput
层(另一个线性层 + 层归一化 + Dropout),将维度从intermediate_size
恢复为hidden_size
。
• 完整FFN流程:
BertSelfAttention → BertIntermediate → BertOutput
5. 总结
组件 | 作用 |
---|---|
线性变换 (self.dense ) | 扩展特征维度,增强模型表达能力。 |
激活函数 (intermediate_act_fn ) | 引入非线性,使模型能拟合复杂函数关系。 |
6. 参数示例
假设 config
包含以下配置:
config.hidden_size = 768
config.intermediate_size = 3072
config.hidden_act = "gelu"
• 初始化结果:
• self.dense
: nn.Linear(768, 3072)
• self.intermediate_act_fn
: torch.nn.functional.gelu
Ending
BertIntermediate
是BERT模型中前馈神经网络的核心组件,通过线性变换扩展维度并应用非线性激活函数,为模型提供深层特征提取能力。其设计遵循Transformer架构的标准模式(扩展→激活→压缩),与自注意力机制协同工作,共同捕捉序列数据的复杂依赖关系。