model_bert.py
BertEmbeddings类
pytorch中nn.Embedding原理及使用
https://www.jianshu.com/p/63e7acc5e890
token embeddings
def __init__(self, config):
super().__init__()
# vocab_size默认为30522;hidden_size默认为768
# 【word_embeddings】词典大小:vocab_size;向量维度:hidden_size
# token embedding层是要将各个词转换成固定维度的向量。在BERT中,每个词会被转换成768维的向量表示。
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids) # wordpiece分词后的token编码
segment embeddings
上图的实现和tokenization_bert.py中的【build_inputs_with_special_tokens函数】【get_special_tokens_mask函数】【create_token_type_ids_from_sequences函数】有关。
# 【token_type_embeddings】词典大小:type_vocab_size;向量维度:hidden_size
# NSP操作:判断属于句子A还是句子B
def __init__(self, config):
super().__init__()
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
if token_type_ids is None: # 如果这个参数是None,则不做NSP下个句子预测处理,句子类型全用0表示
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
token_type_embeddings = self.token_type_embeddings(token_type_ids) # 句子对的分割编码
position embeddings
BERT能够处理最长512个token的输入序列。 论文作者通过让BERT在各个位置上学习一个向量表示来讲序列顺序的信息编码进来。
这意味着Position Embeddings layer 实际上就是一个大小为 (512, 768)
的lookup表,表的第一行是代表第一个序列的第一个位置,第二行代表序列的第二个位置,以此类推。 因此,如果有这样两个句子“Hello
world” 和“Hi there”, “Hello” 和“Hi”会由完全相同的position
embeddings,因为他们都是句子的第一个词。同理,“world” 和“there”也会有相同的position embedding。
引用自 https://www.cnblogs.com/d0main/p/10447853.html
# 【position_embeddings】词典大小:max_position_embeddings;向量维度:hidden_size
# position_embeddings是一个大小为(512, 768)的lookup表,表的第一行是代表第一个序列的第一个位置,以此类推
# 两个句子的同一个位置的位置编码相同
def __init__(self, config):
super().__init__()
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
position_embeddings = self.position_embeddings(position_ids) # 位置编码
BertSelfAttention类
hasattr() 函数用于判断对象是否包含对应的属性。
hasattr(object, name)
参数:object – 对象;name – 字符串,属性名。