Improving BERT Fine-Tuning via Self-Ensemble and Self-Distillation
1、基本信息
机构:复旦大学 华为公司
时间:2020
下载地址: https://arxiv.org/abs/2002.10345.
2、论文出发点
目前对于BERT效果的提升主要采用如下三种方式:
1)修改BERT模型结构;
2)重新设计预训练任务;
3)引入外部数据。
本文从微调的角度出发提出self-ensemble和self-distillation两种微调机制。
3、方法引入铺垫
目前存在以下两种常用方式。
1)投票方式
使用不同的随机数种子,微调多个BERT模型,将多个BERT模型的输出概率求和,集成模型的输出结果是概率最高的预测类别。(不太理解,个人对投票方式的理解是五个模型对样本进行预测,三个模型预测为正例,两个模型预测为负例,则该样本的预测类别为正),这种方式需要训练多个BERT模型,效率比较低
2)平均方式
微调多个BERT模型,将多个模型的参数求平均作为一个新的BERT模型的参数
4、本文方法
在现有方法的基础上,本文提出了自集成方式,在自集成方式的基础上又提出了Self-Distillation-Average(SDA)和Self-Distillation-Voted(SDV)两种方式
1)自集成方式
由于训练多个BERT模型会带来很多时间和计算资源的消耗,因此本文借鉴参数平均的方式提出了自集成的方式。
在自集成方式中包括两个模型,将其命名为模型A和模型B,模型B在训练过程中参数会不断更新,而模型A的参数是模型B在T个训练时间步内所有参数的平均值。假设T为2,模型B在经过第一个时间步以后,参数为[1,2,3,4],经过第二个时间步以后参数更新为[5,6,7,8],则经过2个时间步以后模型A的参数为[(1+5)/2,(2+6)/2,(3+7)/2,(4+8)/2)]。
2)Self-Distillation-Average(SDA)
本文在自集成方式的基础之上,进一步利用蒸馏的方式来提升模型效果。在该方法中同样包括两个模型,分别称作student model和 teacher model,其中student model损失函数的值由以下两部分损失函数值相加构成:
1)预测值和真实标签之间的交叉熵损失函数值
2)student model和teacher model输出值之间的均方误差值
其中teacher model的参数是student model在T个时间步内参数的平均值(此处即自集成方式)
疑惑:目前不是太明白student model和teacher model输出值之间的均方误差值起到什么作用??论文提到teacher model的参数是student model在各个时间步参数的平均值,因此具有更强的健壮性,
个人思考:teacher model的参数是student model在各个时间步的平均值,在损失函数中添加二者之间的均方误差值,是为了让 student model更加接近teacher model的效果,即各个时间步的平均效果。
3)Self-Distillation-Voted(SDV)
为例和SDA方法进行比较,本文又提出SDV方法,与SDA方法的不同之处在于,SDV方法将teacher model替换为自投票模型,student model和自投票模型之间的均方误差同样作为损失函数值得一部分。
5、实验
1)实验任务
在五个文本分类数据集和两个自然语言推理数据集中进行实验,详细如下图所示:
2)实验结果
具体实验结果如下图所示:
结论:在文本分类任务中,SDA比SDV方法效果更好,在自然语言推理任务中SDV则优于SDA