A detailed look at the two training tasks in BERT pretraining


As is well known, BERT's pretraining includes two training tasks: next sentence prediction (NSP) and mask prediction (masked language modeling, MLM).

  • next sentence prediction: the input is [CLS] a [SEP] b [SEP], and the model predicts whether sentence b actually follows sentence a, i.e. a binary classification problem;
  • mask prediction: the input is [CLS] 我 [MASK] 中 [MASK] 天 安 门 [SEP], and the model predicts the original token at each [MASK] position, a multi-class classification over the vocabulary (see the tokenizer sketch below for how such inputs are built)
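
To make the input format concrete, here is a minimal sketch using the BertTokenizer from the transformers library (the checkpoint name bert-base-chinese and the example sentences are illustrative choices, not from the original post): encoding a sentence pair produces exactly the [CLS] a [SEP] b [SEP] layout described above.

```python
from transformers import BertTokenizer

# Any BERT checkpoint would do; bert-base-chinese is just an example.
tok = BertTokenizer.from_pretrained("bert-base-chinese")

# Encoding a sentence pair (a, b) yields [CLS] a [SEP] b [SEP].
enc = tok("我爱北京", "天安门")
print(tok.convert_ids_to_tokens(enc["input_ids"]))
# ['[CLS]', '我', '爱', '北', '京', '[SEP]', '天', '安', '门', '[SEP]']
```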

I was long puzzled about how BERT actually carries out these two tasks after the transformer encoder. For example:

  • What is the loss function of each task?
  • What exactly is done with the encoder output?

Many people answer along these lines:

  • BERT's loss is the next-sentence classification loss plus the masked-token prediction loss
  • You just compute the loss directly from the encoder output

But if you press further:

  • So how exactly is each of those losses computed?

most people cannot say. The rest of this post spells it out in detail; I believe it will leave you with a much better understanding of BERT.

Source code walkthrough:

The relevant part of the forward method in the BertModel class is shown below.
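The original post showed screenshots of the source here and below; the sketches that follow are reconstructions from the HuggingFace pytorch-pretrained-bert implementation, which is where these class names come from (shapes in comments ignore the batch dimension).

```python
# Inside BertModel.forward (abridged):
encoded_layers = self.encoder(embedding_output,
                              extended_attention_mask,
                              output_all_encoded_layers=output_all_encoded_layers)
sequence_output = encoded_layers[-1]          # last encoder layer, (len, 768)
pooled_output = self.pooler(sequence_output)  # (1, 768)
if not output_all_encoded_layers:
    encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output
```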
encoded_layers holds the hidden states of every encoder layer; sequence_output is the output of the last layer. sequence_output first goes through a self.pooler transformation:
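The BertPooler class (reconstructed sketch):

```python
import torch.nn as nn

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # 768 x 768
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # "Pool" the sequence by taking only the first token's ([CLS]) hidden state.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
```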
Note carefully: after sequence_output enters the BertPooler class, the vector of the sentence's first token (first_token_tensor, i.e. [CLS]) is taken and fed into self.dense. From the class's initializer we can see that self.dense is a linear layer with a 768 × 768 weight. After that linear layer comes a Tanh activation, giving pooled_output (1 × 768).

So after the BertModel class we get two outputs: sequence_output (len × 768) and pooled_output (1 × 768) (shapes here and below omit the batch dimension).
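In BertForPreTraining, these two outputs are produced and consumed like so (reconstructed sketch of the call site):

```python
# Inside BertForPreTraining.forward (abridged):
sequence_output, pooled_output = self.bert(
    input_ids, token_type_ids, attention_mask,
    output_all_encoded_layers=False)
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
```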
They are then passed through self.cls, an instance of the BertPreTrainingHeads class.
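BertPreTrainingHeads bundles the two task heads (reconstructed sketch):

```python
class BertPreTrainingHeads(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)  # NSP head, 768 x 2

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)          # (len, vocab_size)
        seq_relationship_score = self.seq_relationship(pooled_output)  # (1, 2)
        return prediction_scores, seq_relationship_score
```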

Inside BertPreTrainingHeads, sequence_output first goes through self.predictions, an instance of the BertLMPredictionHead class:
Inside BertLMPredictionHead it passes through self.transform and then self.decoder:
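A reconstructed sketch of BertLMPredictionHead and the BertPredictionHeadTransform it uses (ACT2FN is the library's map from activation names to functions):

```python
import torch
import torch.nn as nn

class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # 768 x 768
        self.transform_act_fn = ACT2FN[config.hidden_act]  # gelu in the BERT config
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

class BertLMPredictionHead(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        # The decoder maps 768 -> vocab_size. Its weight is tied to the word
        # embedding matrix, so the only new parameter here is the bias.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states
```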
self.transform applies self.dense (a 768 × 768 linear layer), then self.transform_act_fn (per the BERT config this is gelu), then self.LayerNorm (layer normalization).

self.decoder is a single linear layer mapping 768 → vocab_size. Its weight is tied to the word-embedding matrix via bert_model_embedding_weights, so the MLM head reuses the embedding parameters rather than learning a separate vocab_size × 768 projection.

We have finally reached the bottom, so let's retrace the path and summarize, from the bottom up.
Two inputs:
1. sequence_output (len × 768): the last layer's hidden states. It goes through self.transform's dense layer (768 × 768 linear) -> self.transform_act_fn (gelu) -> self.LayerNorm -> self.decoder (768 × vocab_size, weights tied to the word embeddings).
After this chain we have prediction_scores (len × vocab_size) and can compute the loss directly, using CrossEntropyLoss. In this implementation, positions that were not masked carry the label -1 and are skipped via ignore_index=-1, so only masked tokens contribute to the loss.

2. pooled_output (1 × 768): the pooled [CLS] hidden state. It goes through self.seq_relationship (a 768 × 2 linear layer).
That yields seq_relationship_score (1 × 2), a binary classification, and the loss is again computed directly with CrossEntropyLoss.

Finally, loss = the sum of the two losses above.
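The corresponding loss code (reconstructed sketch of the end of BertForPreTraining.forward):

```python
from torch.nn import CrossEntropyLoss

# Inside BertForPreTraining.forward (abridged):
if masked_lm_labels is not None and next_sentence_label is not None:
    loss_fct = CrossEntropyLoss(ignore_index=-1)  # label -1 = unmasked position, skipped
    masked_lm_loss = loss_fct(
        prediction_scores.view(-1, self.config.vocab_size),
        masked_lm_labels.view(-1))
    next_sentence_loss = loss_fct(
        seq_relationship_score.view(-1, 2),
        next_sentence_label.view(-1))
    total_loss = masked_lm_loss + next_sentence_loss
    return total_loss
```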
This also explains why BERT's pretraining parameters include the following:

  • bert.pooler.dense.weight
  • bert.pooler.dense.bias
  • cls.seq_relationship.weight
  • cls.seq_relationship.bias

If anything here is incorrect, please don't hesitate to point it out. Many thanks!

Below is the full list of BERT pretraining parameters:

bert.embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.weight
bert.embeddings.LayerNorm.bias
bert.encoder.layer.0.attention.self.query.weight
bert.encoder.layer.0.attention.self.query.bias
bert.encoder.layer.0.attention.self.key.weight
bert.encoder.layer.0.attention.self.key.bias
bert.encoder.layer.0.attention.self.value.weight
bert.encoder.layer.0.attention.self.value.bias
bert.encoder.layer.0.attention.output.dense.weight
bert.encoder.layer.0.attention.output.dense.bias
bert.encoder.layer.0.attention.output.LayerNorm.weight
bert.encoder.layer.0.attention.output.LayerNorm.bias
bert.encoder.layer.0.intermediate.dense.weight
bert.encoder.layer.0.intermediate.dense.bias
bert.encoder.layer.0.output.dense.weight
bert.encoder.layer.0.output.dense.bias
bert.encoder.layer.0.output.LayerNorm.weight
bert.encoder.layer.0.output.LayerNorm.bias
bert.encoder.layer.1.attention.self.query.weight
bert.encoder.layer.1.attention.self.query.bias
bert.encoder.layer.1.attention.self.key.weight
bert.encoder.layer.1.attention.self.key.bias
bert.encoder.layer.1.attention.self.value.weight
bert.encoder.layer.1.attention.self.value.bias
bert.encoder.layer.1.attention.output.dense.weight
bert.encoder.layer.1.attention.output.dense.bias
bert.encoder.layer.1.attention.output.LayerNorm.weight
bert.encoder.layer.1.attention.output.LayerNorm.bias
bert.encoder.layer.1.intermediate.dense.weight
bert.encoder.layer.1.intermediate.dense.bias
bert.encoder.layer.1.output.dense.weight
bert.encoder.layer.1.output.dense.bias
bert.encoder.layer.1.output.LayerNorm.weight
bert.encoder.layer.1.output.LayerNorm.bias
bert.encoder.layer.2.attention.self.query.weight
bert.encoder.layer.2.attention.self.query.bias
bert.encoder.layer.2.attention.self.key.weight
bert.encoder.layer.2.attention.self.key.bias
bert.encoder.layer.2.attention.self.value.weight
bert.encoder.layer.2.attention.self.value.bias
bert.encoder.layer.2.attention.output.dense.weight
bert.encoder.layer.2.attention.output.dense.bias
bert.encoder.layer.2.attention.output.LayerNorm.weight
bert.encoder.layer.2.attention.output.LayerNorm.bias
bert.encoder.layer.2.intermediate.dense.weight
bert.encoder.layer.2.intermediate.dense.bias
bert.encoder.layer.2.output.dense.weight
bert.encoder.layer.2.output.dense.bias
bert.encoder.layer.2.output.LayerNorm.weight
bert.encoder.layer.2.output.LayerNorm.bias
bert.encoder.layer.3.attention.self.query.weight
bert.encoder.layer.3.attention.self.query.bias
bert.encoder.layer.3.attention.self.key.weight
bert.encoder.layer.3.attention.self.key.bias
bert.encoder.layer.3.attention.self.value.weight
bert.encoder.layer.3.attention.self.value.bias
bert.encoder.layer.3.attention.output.dense.weight
bert.encoder.layer.3.attention.output.dense.bias
bert.encoder.layer.3.attention.output.LayerNorm.weight
bert.encoder.layer.3.attention.output.LayerNorm.bias
bert.encoder.layer.3.intermediate.dense.weight
bert.encoder.layer.3.intermediate.dense.bias
bert.encoder.layer.3.output.dense.weight
bert.encoder.layer.3.output.dense.bias
bert.encoder.layer.3.output.LayerNorm.weight
bert.encoder.layer.3.output.LayerNorm.bias
bert.encoder.layer.4.attention.self.query.weight
bert.encoder.layer.4.attention.self.query.bias
bert.encoder.layer.4.attention.self.key.weight
bert.encoder.layer.4.attention.self.key.bias
bert.encoder.layer.4.attention.self.value.weight
bert.encoder.layer.4.attention.self.value.bias
bert.encoder.layer.4.attention.output.dense.weight
bert.encoder.layer.4.attention.output.dense.bias
bert.encoder.layer.4.attention.output.LayerNorm.weight
bert.encoder.layer.4.attention.output.LayerNorm.bias
bert.encoder.layer.4.intermediate.dense.weight
bert.encoder.layer.4.intermediate.dense.bias
bert.encoder.layer.4.output.dense.weight
bert.encoder.layer.4.output.dense.bias
bert.encoder.layer.4.output.LayerNorm.weight
bert.encoder.layer.4.output.LayerNorm.bias
bert.encoder.layer.5.attention.self.query.weight
bert.encoder.layer.5.attention.self.query.bias
bert.encoder.layer.5.attention.self.key.weight
bert.encoder.layer.5.attention.self.key.bias
bert.encoder.layer.5.attention.self.value.weight
bert.encoder.layer.5.attention.self.value.bias
bert.encoder.layer.5.attention.output.dense.weight
bert.encoder.layer.5.attention.output.dense.bias
bert.encoder.layer.5.attention.output.LayerNorm.weight
bert.encoder.layer.5.attention.output.LayerNorm.bias
bert.encoder.layer.5.intermediate.dense.weight
bert.encoder.layer.5.intermediate.dense.bias
bert.encoder.layer.5.output.dense.weight
bert.encoder.layer.5.output.dense.bias
bert.encoder.layer.5.output.LayerNorm.weight
bert.encoder.layer.5.output.LayerNorm.bias
bert.encoder.layer.6.attention.self.query.weight
bert.encoder.layer.6.attention.self.query.bias
bert.encoder.layer.6.attention.self.key.weight
bert.encoder.layer.6.attention.self.key.bias
bert.encoder.layer.6.attention.self.value.weight
bert.encoder.layer.6.attention.self.value.bias
bert.encoder.layer.6.attention.output.dense.weight
bert.encoder.layer.6.attention.output.dense.bias
bert.encoder.layer.6.attention.output.LayerNorm.weight
bert.encoder.layer.6.attention.output.LayerNorm.bias
bert.encoder.layer.6.intermediate.dense.weight
bert.encoder.layer.6.intermediate.dense.bias
bert.encoder.layer.6.output.dense.weight
bert.encoder.layer.6.output.dense.bias
bert.encoder.layer.6.output.LayerNorm.weight
bert.encoder.layer.6.output.LayerNorm.bias
bert.encoder.layer.7.attention.self.query.weight
bert.encoder.layer.7.attention.self.query.bias
bert.encoder.layer.7.attention.self.key.weight
bert.encoder.layer.7.attention.self.key.bias
bert.encoder.layer.7.attention.self.value.weight
bert.encoder.layer.7.attention.self.value.bias
bert.encoder.layer.7.attention.output.dense.weight
bert.encoder.layer.7.attention.output.dense.bias
bert.encoder.layer.7.attention.output.LayerNorm.weight
bert.encoder.layer.7.attention.output.LayerNorm.bias
bert.encoder.layer.7.intermediate.dense.weight
bert.encoder.layer.7.intermediate.dense.bias
bert.encoder.layer.7.output.dense.weight
bert.encoder.layer.7.output.dense.bias
bert.encoder.layer.7.output.LayerNorm.weight
bert.encoder.layer.7.output.LayerNorm.bias
bert.encoder.layer.8.attention.self.query.weight
bert.encoder.layer.8.attention.self.query.bias
bert.encoder.layer.8.attention.self.key.weight
bert.encoder.layer.8.attention.self.key.bias
bert.encoder.layer.8.attention.self.value.weight
bert.encoder.layer.8.attention.self.value.bias
bert.encoder.layer.8.attention.output.dense.weight
bert.encoder.layer.8.attention.output.dense.bias
bert.encoder.layer.8.attention.output.LayerNorm.weight
bert.encoder.layer.8.attention.output.LayerNorm.bias
bert.encoder.layer.8.intermediate.dense.weight
bert.encoder.layer.8.intermediate.dense.bias
bert.encoder.layer.8.output.dense.weight
bert.encoder.layer.8.output.dense.bias
bert.encoder.layer.8.output.LayerNorm.weight
bert.encoder.layer.8.output.LayerNorm.bias
bert.encoder.layer.9.attention.self.query.weight
bert.encoder.layer.9.attention.self.query.bias
bert.encoder.layer.9.attention.self.key.weight
bert.encoder.layer.9.attention.self.key.bias
bert.encoder.layer.9.attention.self.value.weight
bert.encoder.layer.9.attention.self.value.bias
bert.encoder.layer.9.attention.output.dense.weight
bert.encoder.layer.9.attention.output.dense.bias
bert.encoder.layer.9.attention.output.LayerNorm.weight
bert.encoder.layer.9.attention.output.LayerNorm.bias
bert.encoder.layer.9.intermediate.dense.weight
bert.encoder.layer.9.intermediate.dense.bias
bert.encoder.layer.9.output.dense.weight
bert.encoder.layer.9.output.dense.bias
bert.encoder.layer.9.output.LayerNorm.weight
bert.encoder.layer.9.output.LayerNorm.bias
bert.encoder.layer.10.attention.self.query.weight
bert.encoder.layer.10.attention.self.query.bias
bert.encoder.layer.10.attention.self.key.weight
bert.encoder.layer.10.attention.self.key.bias
bert.encoder.layer.10.attention.self.value.weight
bert.encoder.layer.10.attention.self.value.bias
bert.encoder.layer.10.attention.output.dense.weight
bert.encoder.layer.10.attention.output.dense.bias
bert.encoder.layer.10.attention.output.LayerNorm.weight
bert.encoder.layer.10.attention.output.LayerNorm.bias
bert.encoder.layer.10.intermediate.dense.weight
bert.encoder.layer.10.intermediate.dense.bias
bert.encoder.layer.10.output.dense.weight
bert.encoder.layer.10.output.dense.bias
bert.encoder.layer.10.output.LayerNorm.weight
bert.encoder.layer.10.output.LayerNorm.bias
bert.encoder.layer.11.attention.self.query.weight
bert.encoder.layer.11.attention.self.query.bias
bert.encoder.layer.11.attention.self.key.weight
bert.encoder.layer.11.attention.self.key.bias
bert.encoder.layer.11.attention.self.value.weight
bert.encoder.layer.11.attention.self.value.bias
bert.encoder.layer.11.attention.output.dense.weight
bert.encoder.layer.11.attention.output.dense.bias
bert.encoder.layer.11.attention.output.LayerNorm.weight
bert.encoder.layer.11.attention.output.LayerNorm.bias
bert.encoder.layer.11.intermediate.dense.weight
bert.encoder.layer.11.intermediate.dense.bias
bert.encoder.layer.11.output.dense.weight
bert.encoder.layer.11.output.dense.bias
bert.encoder.layer.11.output.LayerNorm.weight
bert.encoder.layer.11.output.LayerNorm.bias
bert.pooler.dense.weight
bert.pooler.dense.bias
cls.predictions.bias
cls.predictions.transform.dense.weight
cls.predictions.transform.dense.bias
cls.predictions.transform.LayerNorm.gamma
cls.predictions.transform.LayerNorm.beta
cls.predictions.decoder.weight
cls.seq_relationship.weight
cls.seq_relationship.bias
