看了好多 Bert 的介绍文章,少有从源码层面理解Bert模型的,本文章将根据一行一行源码来深入理解Bert模型
BertBertModel 类的源码如下(删除注释)
class BertModel(BertPreTrainedModel):
def __init__(self, config):
super(BertModel, self).__init__(config)
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(embedding_output,
extended_attention_mask,
output_all_encoded_layers=output_all_encoded_layers)
sequence_output = encoded_layers[-1]
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output
1. __init__函数
首先分析BertBertModel 类的__init__函数(构造函数),__init__函数分别定义了embeddings 、encoder 、pooler三个模块,这三个模块分别是词嵌入模块,encoder模块和分类层模块;然后对模型的参数进行了初始化,通过__init__函数我们也能了解到,Bert主要就是由这三个模块组成,关于这三个模块的详细解析请参考以下文章:
2. forward函数
forward函数是Bert模型的向前传播函数,作用是将数据从模型的输入传送到输出;
forward函数有四个参数,分别是 input_ids、token_type_ids、attention_mask 和 output_all_encoded_layers,四个参数的解释如下:
参数 | 含义 | 维度 |
---|---|---|
input_ids | token在词汇表中索引组成的数组 | [batch_size, sequence_length] |
token_type_ids | 用于标识当前token属于哪一个句向量(0属于第一句,1属于第二句) | [batch_size, sequence_length] |
attention_mask | 如果输入序列长度小于当前批次中的最大输入序列长度,则使用此掩码,用于指示序列的那些输入需要被Mask,当前位置是小于等于真实长度值为1 大于为0 | [batch_size,sequence_length] |
output_all_encoded_layers | True:输出全部12层encoder的输出,False:只输出最后一层encoder的值 | / |
-
在 forward 函数中,首先对
token_type_ids
和attention_mask
参数为None值的情况进行了处理;当token_type_ids
为 None 时,生成一个[batch_size, sequence_length]
形状的数组赋值给token_type_ids
并将token_type_ids
所有位置置为0,表示每个序列中只包含一个句子;当attention_mask
为None
时,生成一个[batch_size, sequence_length]
形状的数组赋值给attention_mask
并将attention_mask
中的所有值置为1,表示当前序列的所有数据都为有效数据; -
然后
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
这句代码的含义是将extended_attention_mask 维度连续扩充两次,依次是 [batch_size,sequence_length]->[batch_size,1,sequence_length] ->[batch_size,1,1,sequence_length]
关于 Pytorch 中 unsqueeze()和 squeeze()函数的详细介绍,请参考这篇博客->#彻底理解# pytorch 中的 squeeze() 和 unsqueeze()函数 -
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
(上一行是格式转换,此处略过)这行代码的作用是将 extended_attention_mask 中的 0 变为 -10000,(0表示此位置,不是有效数据,当-10000进入self attention后,权重会变得非常小乃至可以这些token) -
embedding_output = self.embeddings(input_ids, token_type_ids)
这行代码的作用是根据 input_ids,token_type_ids 两个矩阵生成 token 对应的 word embedding,( 熟悉 Bert 原理的朋友都知道,除了这两个输入矩阵,还有一个位置编码矩阵,这个矩阵会在 embeddings 模块内部自动生成) -
encoded_layers = self.encoder...
这句代码的意思是将 embedding 层的输出输入到encoder
模块 并将 encoder 模块的输出赋值给encoded_layers
,其中encoded_layers是一个长度为12的数组,保存着12层encoder每层对应的输出 -
sequence_output = encoded_layers[-1] pooled_output = self.pooler(sequence_output)
这两句代码的作用是生成 Bert 的pooled_output
输出;做法是:将encoder
最后一层的输出赋值给 Bert 的pooler
模块(输入768,输出768,tanh激活函数的全连接层)并将结果返回;虽然pool模块的输入为768,但只使用第一个 token([CLS])的对应的输入,因此pooled_output
输出一般用于解决 语句级别的任务。 -
根据
output_all_encoded_layers
参数的值返回所有12层encoder
的输出或仅输出最后一层encoder
的输出;同时输出pooled_output
想深入了解Bert模型原理的朋友可以阅读这5篇文章:
- #深入理解# Bert框架原理
- #由浅入深# 从 Seq2seq 到 Transformer
- #最全面# 使用 Bert 解决下游 NLP 实际任务
- 从NLP中的标记算法(tokenization)到bert中的WordPiece
- #彻底理解# NLP中的word2vec
接下来将进一步分析 Bert 模型中的 embeddings、encoder、pooler 等模块,链接如下: