从源码层面,深入理解 Bert 框架

65 篇文章 2 订阅
23 篇文章 12 订阅

看了好多 Bert 的介绍文章,少有从源码层面理解Bert模型的,本文章将根据一行一行源码来深入理解Bert模型

BertBertModel 类的源码如下(删除注释)

class BertModel(BertPreTrainedModel):
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output

1. __init__函数

首先分析BertBertModel 类的__init__函数(构造函数),__init__函数分别定义了embeddings 、encoder 、pooler三个模块,这三个模块分别是词嵌入模块,encoder模块和分类层模块;然后对模型的参数进行了初始化,通过__init__函数我们也能了解到,Bert主要就是由这三个模块组成,关于这三个模块的详细解析请参考以下文章:

  1. 从源码解析 Bert 的 Embedding 模块​
  2. 从源码解析 Bert 的 BertEncoder 模块
  3. 从源码解析 Bert 的 BertPooler 模块

2. forward函数

forward函数是Bert模型的向前传播函数,作用是将数据从模型的输入传送到输出;
forward函数有四个参数,分别是 input_ids、token_type_ids、attention_mask 和 output_all_encoded_layers,四个参数的解释如下:

参数含义维度
input_idstoken在词汇表中索引组成的数组[batch_size, sequence_length]
token_type_ids用于标识当前token属于哪一个句向量(0属于第一句,1属于第二句)[batch_size, sequence_length]
attention_mask如果输入序列长度小于当前批次中的最大输入序列长度,则使用此掩码,用于指示序列的那些输入需要被Mask,当前位置是小于等于真实长度值为1 大于为0[batch_size,sequence_length]
output_all_encoded_layersTrue:输出全部12层encoder的输出,False:只输出最后一层encoder的值/
  1. 在 forward 函数中,首先对 token_type_idsattention_mask 参数为None值的情况进行了处理;当 token_type_ids 为 None 时,生成一个 [batch_size, sequence_length] 形状的数组赋值给token_type_ids并将 token_type_ids所有位置置为0,表示每个序列中只包含一个句子;当attention_maskNone时,生成一个[batch_size, sequence_length]形状的数组赋值给attention_mask并将attention_mask中的所有值置为1,表示当前序列的所有数据都为有效数据;

  2. 然后 extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 这句代码的含义是将extended_attention_mask 维度连续扩充两次,依次是 [batch_size,sequence_length]->[batch_size,1,sequence_length] ->[batch_size,1,1,sequence_length]
    关于 Pytorch 中 unsqueeze()和 squeeze()函数的详细介绍,请参考这篇博客->#彻底理解# pytorch 中的 squeeze() 和 unsqueeze()函数

  3. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 (上一行是格式转换,此处略过)这行代码的作用是将 extended_attention_mask 中的 0 变为 -10000,(0表示此位置,不是有效数据,当-10000进入self attention后,权重会变得非常小乃至可以这些token)

  4. embedding_output = self.embeddings(input_ids, token_type_ids) 这行代码的作用是根据 input_ids,token_type_ids 两个矩阵生成 token 对应的 word embedding,( 熟悉 Bert 原理的朋友都知道,除了这两个输入矩阵,还有一个位置编码矩阵,这个矩阵会在 embeddings 模块内部自动生成)

  5. encoded_layers = self.encoder...这句代码的意思是将 embedding 层的输出输入到 encoder 模块 并将 encoder 模块的输出赋值给 encoded_layers,其中encoded_layers是一个长度为12的数组,保存着12层encoder每层对应的输出

  6. sequence_output = encoded_layers[-1] pooled_output = self.pooler(sequence_output)这两句代码的作用是生成 Bert 的pooled_output输出;做法是:将encoder最后一层的输出赋值给 Bert 的 pooler 模块(输入768,输出768,tanh激活函数的全连接层)并将结果返回;虽然pool模块的输入为768,但只使用第一个 token([CLS])的对应的输入,因此 pooled_output 输出一般用于解决 语句级别的任务。

  7. 根据 output_all_encoded_layers参数的值返回所有12层encoder的输出或仅输出最后一层encoder的输出;同时输出 pooled_output

想深入了解Bert模型原理的朋友可以阅读这5篇文章:

  1. #深入理解# Bert框架原理
  2. #由浅入深# 从 Seq2seq 到 Transformer
  3. #最全面# 使用 Bert 解决下游 NLP 实际任务
  4. 从NLP中的标记算法(tokenization)到bert中的WordPiece
  5. #彻底理解# NLP中的word2vec

接下来将进一步分析 Bert 模型中的 embeddings、encoder、pooler 等模块,链接如下:

  1. 从源码解析 Bert 的 Embedding 模块​
  2. 从源码解析 Bert 的 BertEncoder 模块
  3. 从源码解析 Bert 的 BertPooler 模块
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

energy_百分百

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值