单位:HuggingFace
时间:2020.5
发表:NIPS2019
论文链接:https://arxiv.org/pdf/1910.01108.pdf
一、背景
1. 什么是distill(蒸馏)?
蒸馏简单的说是将大模型(teacher)的学习结果,作为小模型(student)的学习目标,意在小模型能学习到大模型的表示。
蒸馏这个方法的核心思想是:好模型的目标不是拟合训练数据而是学习如何泛化到新的数据。
所以蒸馏的目标是让学生模型学习到教师模型的泛化能力,理论上得到的结果会比单纯拟合训练数据的学生模型要好。
2. BERT有哪些短板?
从应用落地的角度来说,bert虽然效果好,但有一个短板就是预训练模型太大,预测时间在平均在300ms以上(一条数据),无法满足线上并发量要求高的业务需求。
二、DistilBERT, a distilled version of BERT
1. 作者的思路
之前的模型蒸馏本质上都是两个loss,即distillation loss和student loss
这样模型学到的都是精调后的知识,即模型都是任务相关的,作者想蒸馏出一个任务无关的BERT,这样通用性更强,在具体任务时做具体的精调即可。
2. 具体做法
I. 模型结构
教师模型采用预训练好的BERT-base,学生模型则是6层的transformer。
II. 学生模型初始化方法
采用了BERT-PKD提出的PKD-skip的方式进行初始化,即用BERT-base的第[2,4,6,8,10]层的参数作为学生模型的参数。
III. Loss的设计
损失函数最终有三个,具体为:
-
MLM loss