BERT Pytorch版本 源码解析(二)
四、BertEmbedding 类解析
BertEmbedding部分是组成 BertModel 的第一部分,今天就来讲讲 BertEmbedding 的内部实现细节。
4.1、Embedding 的组成以及设置
def __init__(self, config):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
上面的代码是 BertEmbedding 类的初始化函数,在这块很明显 BertEmbedding 并似乎并没有很特别的地方。总的是设置了三种类型的 embedding,分别是word_embedding,position_embedding,token_type_embedding三种组成。首先,这三种embedding都是用pytorch自带的nn.Embedding 随机生成的,而且它们的向量长度都是 config.hidden_size。之后是一个常见的LayerNorm 以及 Dropout层,这部分就不解释了。
4.2、具体实现
def forward(self, input_ids, token_type_ids=None):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = words_embeddings + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
首先输入是input_ids或者token_type_ids,input_ids是一个[Batch_size, Seq_length]维度的向量,每一个元素表示对应词表中的index,token_type_ids是对于一个输入存在两个句子的情况,利用 0 和 1 来区分第几个句子的,所以这个部分其实对于大部分任务来说是可以省略的。
然后是关于position_ids的生成,它是自动生成的一个向量,torch.arange(seq_length)是自动生成一个从0开始到seq_length - 1的长度为seq_length的向量。
如果 token_type_ids 是None的情况下则自动生成一个全为0的向量,即所有的输入都是单句的输入。
之后就是利用nn.Embedding来生成三个[Batch_size, Seq_length, Hidden_size]的向量,然后将三个向量进行叠加操作之后进行LayerNorm以及Dropout操作,这就是BertEmbedding的工作原理。