Distilling transformers into simple neural networks with unlabeled transfer data
论文地址:https://arxiv.org/pdf/1910.01769.pdf
motivation
一般来说,蒸馏得到的student模型与teacher模型的准确率还存在差距。文章利用大量in-domain unlabeled transfer set以及有限数量的标记训练实例来弥补这一差距。
运用在文本分类任务上。
文章提出的两个蒸馏方法:
- hard distillation:用fine-tuned teacher对大量的无标签数据进行标注,标注硬标签。然后用这些augmented data来对student进行监督学习。loss函数是交叉墒。
- soft distillation:用教师模型在unlabeled data上生成的logits和内部表示,来对student进行不同蒸馏方式(不同loss函数)的训练。
模型的输入用Wordpiece tokenization
student模型
embedding层+BiLSTM层+最大池化层(因为如果只用最后一个hidden state,对于长语句来说信息量不够)
损失函数为交叉墒函数
teacher模型
用标注数据来fine tune预训练模型,用的是最后一层的 [CLS]向量,loss是交叉墒。
选什么特征来蒸馏
teacher logits
对于unlabeled数据,教师产生的logits和学生生成的分类score之间的loss,均方误差。
hidden teacher representations
用教师学到的中间表示来指导学生模仿自己。文中用的是教师模型最后一层。
因为两个模型结构不同,最后一层的维度也会不同。用Gelu激活函数进行转换到相同维度。
然后依然是两个模型最后一层表示之间的均方误差作为loss。
(文中提到,他们发现均方误差得到的结果比KL散度更好些)
损失函数:
三个loss函数组合起来,不同的权重。
较高的α值使学生模型更多地关注容易实现的目标。而较高的γ可使学生专注于困难的目标,并使模型适应嘈杂的地面真相标签。后者不是这项工作的重点,因此在进一步分析中将其省略。
训练方式:
- 联合训练,三个loss函数并在一起,
- 逐渐解冻的分层训练。
- 第一步,先训练LRL ,学习参数,模仿teacher最后一层的表示。
- 第二步,以LCE和LLL作为loss,但是不能一下子优化所有参数,会造成灾难性遗忘。因此,将每层的参数frozen,然后从最后一层一层的解冻。直到收敛。
- 先蒸馏,再finetune,与方式2相似,
- 第一步,先以LRL 和LLL作为loss,不需要labeled数据。
- 第二步,用labeled数据进行fine tune,loss为LCE,微调的时候和方式2一样,逐层解冻。
这样相当于第一步得到了一个蒸馏后的student,然后之后就可以根据不同的任务数据来fine tune它。
实验
四个数据集:
- IMDB:电影评论情感分类
- Elec:亚马逊电子产品的情感分类
- DbPedia:Wikipedia的主题分类
- Ag News:新闻文章的主题分类。
一些参数
数据比例 train:validation=9:1
Tensorflow
4 Tesla V100 gpus
Adadelta优化器 + early stopping (也用了Adam,Adam收敛更快,但是最终结果没有Adadelta好)
所有层dropout=0.4,Bi-LSTM层dropout=0.2
300d的glove预训练词向量。
LSTM隐层维度=600,batch size=64,
loss函数中的 α = β = 10, γ = 1
教师模型
选的是BERT-base和BERT-large
学生模型:
- Bi-LSTM encoder,最后一个隐向量+soft Max作分类,用基础的空格分词法。只用labeled data,交叉墒为损失函数来训练。
- 不蒸馏的Bi-LSTM encoder+Max pooling,用wordpiece tokenization,训练loss和1相同。
- 蒸馏的student。和上文中提到的一样,三个loss,3种不同的训练方式。
数据处理部分:
有的数据集没有unlabeled data,所以就把数据集分为两部分,一部分有标签,另一部分去掉标签作为unlabeled data。
实验结果:
用BERT-base做老师
可以看出用了wordpiece和加了Max pooling层的模型比普通的RNN模型效果好。然后通过蒸馏可以明显的提升学生模型的准确率,甚至高于teacher。
用更大的BERT-large做老师:
可以看出,teacher性能越好,蒸馏得到的student也越好。(符合思维,好老师教出好学生)
参数量比较
Distilled Student | BERT Base | BERT Large |
---|---|---|
13M | 110M | 340M |
Distilling Hard Targets vs. Soft Logits
hard distillation是指,finetune之后的teacher对unlabeled data进行预测标注,然后用原本的标注数据和teacher标注的数据一起,对student进行蒸馏。不涉及到logits和最后层表示。
Distillation with Less Training Labels
- 每类留500个标注数据的蒸馏结果:
- 每类留100个标注数据:
每类100个,对BERT large进行fine tune时,DbPedia和Ag News两个数据集的微调结果和之前500个的时候差不多,但是IMDB和Elec数据集只有50%的准确率,几乎是随机了。
于是文章又做了一个实验,用这两个数据集对pretrained BERT接着进行预训练,然后用每类100个进行微调。最后再用每类100个进行蒸馏,得到的学生模型甚至超过了BERT large。
不同训练方式的对比:
可以看出,方式2的结果最好,先学习最后一层的表示,再学习logits和CE。
总结
这篇论文的创新性不高,思路其实和Distilling Task-Specific Knowledge from BERT into Simple Neural Networks这篇论文差不多。
但是文章里的实验做的很详细,该比较的点都比较了。论文也写的很清晰易懂。
可以看出蒸馏到小模型确实是Bert这些大模型很好的应用方向。自己也做了相关的实验,确实有效。