BertAttention源码解析
1. BertAttention类 介绍
在 BERT 模型中,BertAttention 是每一个 Transformer 编码器层的核心部分。BERT 模型通常由多个编码器层叠加而成,每个编码器层都包含一个 BertAttention 层,用于捕获输入序列中各个位置之间的依赖关系。
1.1 BertAttention 的组成
BertAttention 类由两个主要部分组成:
(1) 自注意力层 (BertSelfAttention)
该层执行注意力的计算。它基于输入的隐层状态计算注意力权重,并生成上下文向量。
BertSelfAttention 采用了多头注意力机制,它将输入拆分为多个头(每个头独立计算注意力),然后将结果连接起来,以增加模型的表达能力。在计算注意力时,它会考虑输入序列中的所有位置,这样可以捕获远距离的依赖关系。
(2) 输出层 (BertSelfOutput)
该层对 BertSelfAttention 生成的上下文向量进行进一步的线性变换和归一化处理。它包含一个全连接层、dropout层以及层归一化(layer normalization),用于稳定和优化模型的训练。
1.2 BertAttention 的工作流程
- 计算自注意力 (self.self)
在前向传播过程中,首先通过 BertSelfAttention 层对输入隐层状态计算注意力。这个过程涉及对查询、键和值进行线性变换,然后计算点积注意力(dot-product attention)。 - 处理注意力输出 (self.output)
自注意力计算的结果会传递给 BertSelfOutput 层,在这里进行线性变换、添加残差连接(residual connection),再进行层归一化。这一步的输出将作为模型的最终注意力输出传递给后续的 Transformer 层。 - 剪枝(可选)
BertAttention 支持注意力头的剪枝(pruning),这意味着可以在推理或训练过程中动态减少注意力头的数量,以减少计算量和内存占用。
2. BertAttention类 源码解析
源码地址:transformers/src/transformers/models/bert/modeling_bert.py
# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:15
import torch
from torch import nn
from typing import Optional, Tuple
from transformers.models.bert.modeling_bert import BertSelfOutput, BertSelfAttention, BertSdpaSelfAttention
from transformers.pytorch_utils import prune_linear_layer, find_pruneable_heads_and_indices
# 定义一个字典 BERT_SELF_ATTENTION_CLASSES,它将注意力机制的实现方式("eager" 或 "sdpa")映射到相应的类。这允许根据配置选择不同的自注意力实现。
BERT_SELF_ATTENTION_CLASSES = {
"eager": BertSelfAttention,
"sdpa": BertSdpaSelfAttention,
}
class BertAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
# 根据配置对象 config 中指定的注意力实现类型,从 BERT_SELF_ATTENTION_CLASSES 字典中获取相应的自注意力类
# 并使用 config 和 position_embedding_type 实例化自注意力层,将其赋值给 self.self
self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
# 实例化 BertSelfOutput 类,用于处理自注意力层的输出
self.output = BertSelfOutput(config)
# 初始化一个空集合 self.pruned_heads,用于存储被剪枝(移除)的注意力头索引
self.pruned_heads = set()
# 用于修剪(剪枝)指定的注意力头
def prune_heads(self, heads):
# 如果传入的 heads 列表为空,则直接返回,不进行任何操作
if len(heads) == 0:
return
# 通过 find_pruneable_heads_and_indices 函数找到需要被剪枝的注意力头及其对应的索引
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
# 使用 prune_linear_layer 对 query, key, value 线性层以及 output.dense 进行剪枝
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
# 更新超参数,减少注意力头的数量和尺寸
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
# 将被剪枝的注意力头存储到 pruned_heads 集合中
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
# 调用 self-attention 机制计算自注意力输出
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
# 获取自注意力机制输出的第一个元素并传递给输出层,生成最终的注意力输出
attention_output = self.output(self_outputs[0], hidden_states)
# 将注意力输出与其他输出(如果有)结合在一起
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs