模型的结构
模型主要由Embedding,多个Transform-Encoders模块,功能输出层组成。
Embeddings
BERT的输入将会编码成三种Embedding的单位和。
如图所示:
图中的两个特殊符号[CLS]和[SEP],其中[CLS]表示该特征用于分类模型,对非分类模型,该符合可以省去。[SEP]表示分句符号,用于断开输入语料中的两个句子。
- Position Embeddings:位置嵌入是指将单词的位置信息编码成特征向量,位置嵌入是向模型中引入单词位置关系的至关重要的一环。具体请参考Transforms模型中对于PositionEmbeddings的详细解释。
- Segment Embeddings:用于区分两个句子,例如B是否是A的下文(对话场景,问答场景等)。对于句子对,第一个句子的特征值是0,第二个句子的特征值是1。
- Token Embeddings:词向量,使用了WordPiece,是指将单词划分成一组有限的公共子词单元,能在单词的有效性和字符的灵活性之间取得一个折中的平衡。例如图4的示例中‘playing’被拆分成了‘play’和‘ing’。
Encoders
BERT的基础集成单元是Transformer的Encoder。
BERT的论文中介绍了2种版本 BERT:
, L=12, H=768, A=12,总参数=110M
: L=24, H=1024, A=16, 总参数=340M
层数(即 Transformer-Encoder 块个数)表示为 L,将隐藏尺寸表示为 H、自注意力头数表示为 A。在所有实验中,将前馈/滤波器尺寸设置为 4H,即 H=768 时为 3072,H=1024 时为 4096.
将Embeddings输入到Encoders中:
每个位置返回的输出都是一个隐藏层大小的向量(基本版本BERT为768)。以文本分类为例,我们重点关注第一个位置上的输出(第一个位置是分类标识[CLS])
功能输出层
这一层是根据不同的任务来进行调整,以分类任务为例:
模型训练
BERT是一个预训练多任务模型,它的任务是由两个自监督任务组成,即MLM和NSP。
预训练:假设已有A训练集,先用A对网络进行预训练,在A任务上学会网络参数,然后保存以备后用,当来一个新的任务B,采取相同的网络结构,网络参数初始化的时候可以加载A学习好的参数,其他的高层参数随机初始化,之后用B任务的训练数据来训练网络,当加载的参数保持不变时,称为"frozen",当加载的参数随着B任务的训练进行不断的改变,称为“fine-tuning”,即更好地把参数进行调整使得更适合当前的B任务
优点:当任务B的训练数据较少时,很难很好的训练网络,但是获得了A训练的参数,会比仅仅使用B训练的参数更优。
Masked Language Model
Masked Language Model(MLM)和核心思想取自Wilson Taylor在1953年发表的一篇论文[7]。所谓MLM是指在训练的时候随即从输入预料上mask掉一些单词,然后通过的上下文预测该单词,该任务非常像我们在中学时期经常做的完形填空。正如传统的语言模型算法和RNN匹配那样,MLM的这个性质和Transformer的结构是非常匹配的。
如图所示:
在BERT的实验中,15%的WordPiece Token会被随机Mask掉。在训练模型时,一个句子会被多次喂到模型中用于参数学习,但是Google并没有在每次都mask掉这些单词,而是在确定要Mask掉的单词之后,80%的时候会直接替换为[Mask],10%的时候将其替换为其它任意单词,10%的时候会保留原始Token。
80%:my dog is hairy -> my dog is [mask]
10%:my dog is hairy -> my dog is apple
10%:my dog is hairy -> my dog is hairy
这么做的原因是如果句子中的某个Token100%都会被mask掉,那么在fine-tuning的时候模型就会有一些没有见过的单词。加入随机Token的原因是因为Transformer要保持对每个输入token的分布式表征,否则模型就会记住这个[mask]是token ’hairy‘。至于单词带来的负面影响,因为一个单词被随机替换掉的概率只有15%*10% =1.5%,这个负面影响其实是可以忽略不计的。
Next Sentence Prediction
Next Sentence Prediction(NSP)的任务是判断句子B是否是句子A的下文。如果是的话输出’IsNext‘,否则输出’NotNext‘。训练数据的生成方式是从平行语料中随机抽取的连续两句话,其中50%保留抽取的两句话,它们符合IsNext关系,另外50%的第二句话是随机从预料中提取的,它们的关系是NotNext的。这个关系保存在[CLS]符号中。
如图所示:
Input = SentenceA:[CLS] the man went to [MASK] store [SEP]
SentenceB:penguin [MASK] are flight ## less birds [SEP]
Label = NotNext
Input =SentenceA:[CLS] the man went to [MASK] store [SEP]
SentenceB:he bought a gallon [MASK] milk [SEP]
Label = IsNext
这样训练模型,使模型具备理解长序列上下文的联系的能力