Bert学习笔记
参考文章【BERT原理详细介绍:https://blog.csdn.net/weixin_46425692/article/details/108927400】
一、组成结构
Transformer是一个encoder-decoder的结构,由6层编码器和6层解码器构成。其中encoder层包括输入层、注意力机制、前馈神经网络
BertBase是Transformer里面的12层encoder层组成的
二、Bert输入表示
bert将输入句子转化为词向量,是经过了3 个Embedding 的加和, 即
input_embed = Token_embed + Sentence_embed + Position_embed.
Token_embed:词嵌入WordVec2,如随机初始化、Glove初始化
Sentence_embed:用来区分两个句子的,如下面图中由两个句子,第1个句子就可以用0表示,第2个句子就用1表示.
Position_embed:随机初始化让模型学习,支持的序列长度最多为512个token。
设置时,如下所示:
tok_embed = nn.Embedding(vocab_size,d_model)
pos_embed = nn.Embedding(maxlen,d_model)
seg_embed = nn.Embedding(n_segments,d_model)
norm = nn.LayerNorm(d_model)
参数解释为:maxLen:批训练时,最大句子长度。d_model:词嵌入的维度,n_segments:表示输入多少句话,一般最大取2。
对于上面图,一些补充如下:
(1)[CLS] 可以在训练时对应的输出变量接一个二分类器输出分类结果,但不能代表句子语义信息。
(2)输入时[SEP]是一个分隔符,用来分隔两个句子
(3)有两个一起是因为后面的NSP任务会分析2个句子间的关系
如果用pytorch调用模型,只需要输入:
(1)input_ids:一个形状为[batch_size, sequence_length]的 torch.LongTensor,在词汇表中包含单词的token索引, 注意 在句子首尾分别加了 [cls] 和 [sep] 的 索引
(2)segment_ids :形状[batch_size, sequence_length]的可选 torch.LongTensor,在0, 1中选择token类型索引。类型0对应于句子A,类型1对应于句子B。如 [0,0,0,0,0,1,1,1,1,1], 0代表第一个句子A, 1代表第二个句子B,默认全为0
(3)input_mask:一个可选的 torch.LongTensor,形状为[batch_size, sequence_length],索引在0, 1中选择。0 是 padding 的位置,1是没有padding的字
三、做预训练:MLM+NSP
MLM / Masked LM - 掩码语言模型
Masked Language Model(MLM)是指在训练的时候随机从输入语料上mask掉一些单词,然后通过它的上下文预测该单词,该任务非常像我们在中学时期经常做的完形填空。
无监督目标函数包括有AR和AE。其中,AR用的是GRT,是一种自回归模型,是单侧的。AE是用于Bert,是一种自编码模型,可以用于上下文预测。
随机mask语料中15%的token,对于[Mask]这个符号,由于在测试集中不存在,为了减轻训练和预测之间的不匹配,作者按一定的比例在需要预测的token上动了手脚,如:my dog is hairy,则:在15%的单词当中
有80%的概率用“[mask]”标记来替换——my dog is [MASK]
有10%的概率用随机采样的一个单词来替换——my dog is apple
有10%的概率不做替换——my dog is hairy
NSP - 下一句预测(next sentence prediction)
1、从训练语料库中取两个连续文本为正样本
2、从不同文档中随机创建一堆段落为负样本
这里用的是CLS对应的相邻进行分类的
四、如何提升Bert下游任务效果
1、 一般做法:调用别人的Bert,再进行与训练
PreTrain
domain transfer :领域迁移,引入大量相同数据
Task transfer:找更相关的数据
Fine-tune:在任务相关数据上做具体任务(微调)
2、 在相同领域及逆行future pre-training
(1)动态mask:每次epoch去训练的时候mask
(2)n-gram mask:如ERNIE和SpanBert做的实体词的Bert
(3)参数:batch_size、learning_rate(5e-5,3e-5,2e-5)、number of epoch(3,4) 、weighted decay(修改后的adam)、使用warmup搭配线性衰减
其他:如数据增强EDA、自蒸馏、外部知识等
最后参考:【Bert代码详解(一)】https://blog.csdn.net/cpluss/article/details/88418176