BertEmbeddings类源码解析
1. BertEmbeddings介绍
BertEmbeddings 类是 BERT 模型的一个重要组成部分。这个类将词嵌入(Token Embeddings) 、位置嵌入(Position Embeddings) 和 标记类型嵌入(Segment Embeddings) 组合起来,为每个输入token生成最终的嵌入表示,为后续的编码器层提供输入。
图片来源: BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
2. BertEmbeddings类 流程
源码地址:transformers/src/transformers/models/bert/modeling_bert.py
> forward()方法的主要流程 <
1) input_shape
参数:通过input_ids
或者 inputs_embeds
可以得到
input_shape: [batch_size, seq_length]
2)position_ids
参数:如果传入的参数中包含position_ids
, 不用再考虑;否则, 从缓冲区中获取相应的长度
例如:text = "我爱学习", position_ids = tensor([[0, 1, 2, 3, 4, 5]])
3)token_type_ids
参数:如果传入的参数中包含token_type_ids
, 不用再考虑;否则, 从缓冲区中获取相应的长度
例如:text = "我爱学习", token_type_ids = tensor([[0, 0, 0, 0, 0, 0]])
4)inputs_embeds
参数: 一般是根据input_ids
获取, 即:
self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id)
inputs_embeds = self.word_embeddings(input_ids)
5)token_type_embeddings
参数:
self.token_type_embeddings = torch.nn.Embedding(type_vocab_size, hidden_size)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
6)position_embeddings
参数:取决于position_embedding_type == “absolute”
- True: position_embeddings = self.position_embeddings(position_ids) # 使用绝对位置编码
- False: # 不使用绝对编码
7)embeddings
最终结果:
# 使用绝对位置编码:embeddings = inputs_embeds + token_type_embeddings + position_embeddings
# 不使用绝对位置编码:embeddings = inputs_embeds + token_type_embeddings
embeddings = inputs_embeds + token_type_embeddings (+ position_embeddings)
embeddings = self.dropout(self.LayerNorm(embeddings))
3. BertEmbeddings类 代码注释
# -*- coding: utf-8 -*-
# @time: 2024/7/12 11:35
import torch
from torch import nn
from typing import Optional
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
# 词嵌入层:用于将词汇表中的词映射到一个固定大小的向量空间, 使用词汇表大小(config.vocab_size)和隐藏层大小(config.hidden_size)来初始化嵌入层。
# padding_idx 指定了填充标记的索引,以便在计算嵌入时忽略这些位置。
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
# 位置嵌入层:初始化一个嵌入层,表示输入序列中每个位置的嵌入。位置嵌入的数量由 config.max_position_embeddings 指定。
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# 标记类型嵌入层:用于区分不同类型的标记(如句子A和句子B),通常用于处理句子对的任务。
# 嵌入层的大小由 config.type_vocab_size 和 config.hidden_size 指定。
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load any TensorFlow checkpoint file
# 层归一化层:对隐藏层的输出进行归一化,以提高训练的稳定性。eps 是防止归一化过程中出现除零错误的一个小常数。
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# Dropout层:用于在训练时防止过拟合。随机将一些神经元的输出置为0,以防止过拟合。Dropout概率由 config.hidden_dropout_prob 指定。
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
# 位置嵌入类型:从配置中获取 position_embedding_type 属性,如果没有设置,默认为 "absolute"。
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
# 注册一个缓冲区:position_ids: 用于存储位置ID,从 0 到 max_position_embeddings,在内存中是连续的,并在序列化时导出。
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False)
# 注册另一个缓冲区:token_type_ids: 用于存储标记类型ID,初始值全部为0,大小与 position_ids 相同。
self.register_buffer("token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
# input_shape: [batch_size, seq_length]
seq_length = input_shape[1]
# 如果 position_ids 为空,则从 position_ids 缓冲区中获取相应长度的 position_ids
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
# position_ids: tensor([[0, 1, 2, 3, 4, 5]])
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
# issue #5664
# 将 token_type_ids 设置为构造函数中注册的缓冲区(全部为零),这通常在自动生成时发生。
# 注册缓冲区在用户不传递 token_type_ids 时帮助跟踪模型,解决了 issue #5664
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
# token_type_ids: tensor([[0, 0, 0, 0, 0, 0]])
# 如果 inputs_embeds 为空,则通过 input_ids 获取词嵌入
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
# 获取标记类型嵌入
token_type_embeddings = self.token_type_embeddings(token_type_ids)
# 将词嵌入和标记类型嵌入相加
embeddings = inputs_embeds + token_type_embeddings
# 如果位置嵌入类型为 "absolute",则加上位置嵌入
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings) # 进行层归一化
embeddings = self.dropout(embeddings) # 应用 dropout
return embeddings
4. BertEmbeddings类 测试
# -*- coding: utf-8 -*-
# @time: 2024/7/12 16:21
import torch
from transformers import BertTokenizer, BertConfig
from BertEmbeddings import BertEmbeddings
# 加载预训练的 BERT 分词器
model_name = 'google-bert/bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
# 示例文本,包含两句话
text = [("my dog is cute", "he likes play ing")]
inputs = tokenizer(text, truncation=True, padding=True, return_tensors='pt')
input_ids = inputs["input_ids"]
print(input_ids)
token_type_ids = inputs["token_type_ids"]
print(token_type_ids)
configuration = BertConfig()
bert_embeddings = BertEmbeddings(configuration)
with torch.no_grad():
embeddings = bert_embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
print(embeddings)
print(embeddings.shape)
输出:
tensor([[ 101, 2026, 3899, 2003, 10140, 102, 2002, 7777, 2377, 13749,
102]])
tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]])
tensor([[[ 1.1078, 1.1187, 0.8171, ..., -1.7894, -0.3242, -0.3227],
[-1.1724, 0.6758, -0.7380, ..., -2.0156, 0.6501, 0.0658],
[ 2.4001, 1.4040, 0.3180, ..., 0.2173, 0.3123, -0.3133],
...,
[-0.2485, 0.0995, 1.1544, ..., 0.6161, 0.6230, -0.8850],
[ 0.2799, -1.1622, 0.0000, ..., 1.9079, 0.0000, -0.8867],
[-0.1645, -0.0000, -0.3955, ..., 2.5497, -1.2822, -2.1249]]])
torch.Size([1, 11, 768])