转载来源:https://zhuanlan.zhihu.com/p/143027221
自从 BERT 出现后,似乎 NLP 就走上了大力出奇迹的道路。模型越来越大参数越来越多,这直接导致我们需要的资源和时间也越来越多。发文章搞科研似乎没有什么,但是这些大模型很难在实际工业场景落地,不只是因为成本过高,也因为推理速度不支持线上实际情况。最近好多文章都开始针对 BERT 进行瘦身,不管是蒸馏还是减少层数还是参数共享等,都是为了 BERT 系列模型能够更小更快的同时不丢失精度。
论文地址:https://arxiv.org/pdf/2004.02178.pdfarxiv.org
一、概述
FastBERT 是 ACL2020 新鲜出炉的一篇关于提高 BERT 推理速度的文章,这篇文章的思想还是很巧妙的,作者发现 12 层的 BERT 用来对一些简单问题做分类有点大材小用了。那么我们可以不可对于简单问题只用较少层数来解答,对于复杂问题采用全部层数的 BERT 呢?作者在这里对 BERT 的每层输出后面都接了分类器,如果在浅层模型就有很高的置信度对样本进行分类,那么久不再走后面的层了,如果置信度不高,那就继续走后面,这样就极大地缩短了推理的时间。这里每层的分类器并不是单独训练的,而是使用一个总的 teacher classifier 蒸馏出来的,从实验结果可以看出采用自蒸馏要比直接训练效果好。
二、模型详解
BackBone
整个模型的骨架就是采用 BERT 来实现,所有 BERT 系列的模型都可以套用在这里。BERT 在这里的作用还是一个强大的特征提取器的作用,在多层 transformer 堆叠的后面紧跟着一个 teacher classifier,这个在 fine-tune 阶段会进行训练,后面用来蒸馏每一层的 student classifier。整体的模型结构见下图:
图里面的 Branch 就是每个 student classifier,它们具有和 teacher classifier 一样的结构。在实际推理的时候,从底层开始往上,如果有很高的置信度就 early output,不再仅需往后面走了。
Model Training
模型的训练一共包括三个部分,一个是主要骨架模型的预训练,这里和传统的 BERT 模型是一样的;然后就是整个骨架模型的 fine-tuning,这里会训练 teacher classifier;最后是对 teacher classifier 进行蒸馏的到 student classifier。预训练和 fine-tuning 没啥好说的,和 BERT 是一摸一样的,这里主要介绍一下 self-distillation。自蒸馏和传统的蒸馏方式最大的不同就是 teacher 模型和 student 模型是一样的在一个模型里面,传统的方式往往需要单独设计 student 模型。见下图:
但是自蒸馏的话就不存在这个问题了,自蒸馏使用 teacher classifier 的输出 Pt 以及 student 的输出 Ps,然后计算他们的 KL 散度,通过优化所有 student KL 散度 loss 的合来确保 student 和 teacher 的分布越来越相似。具体公式如下:
D
K
L
(
p
s
,
p
t
)
=
∑
i
=
1
N
p
s
(
i
)
⋅
log
p
s
(
i
)
p
t
(
j
)
D_{K L}\left(p_{s}, p_{t}\right)=\sum_{i=1}^{N} p_{s}(i) \cdot \log \frac{p_{s}(i)}{p_{t}(j)}
DKL(ps,pt)=i=1∑Nps(i)⋅logpt(j)ps(i)
Loss
(
p
s
0
,
…
,
p
s
L
−
2
,
p
t
)
=
∑
i
=
0
L
−
2
D
K
L
(
p
s
i
,
p
t
)
\operatorname{Loss}\left(p_{s_{0}}, \ldots, p_{s_{L-2}}, p_{t}\right)=\sum_{i=0}^{L-2} D_{K L}\left(p_{s_{i}}, p_{t}\right)
Loss(ps0,…,psL−2,pt)=i=0∑L−2DKL(psi,pt)
Adaptive Inference
在推理阶段作者使用了自适应的推理,简单来说就是每层的 BERT 都会有一个结果,作者定义了一个输出结果不确定性的度量,用来衡量每层输出的结果是否可信,公式如下:
Uncertainty
=
∑
i
=
1
N
p
s
(
i
)
log
p
s
(
i
)
log
1
N
\text {Uncertainty}=\frac{\sum_{i=1}^{N} p_{s}(i) \log p_{s}(i)}{\log \frac{1}{N}}
Uncertainty=logN1∑i=1Nps(i)logps(i)
根据公式我们可以发现,这个不确定性就是用熵来衡量的。熵越大代表结果越不可信,如果某一层的不确定性小于一个阈值,那么我们就对这层的结果进行输出,从而提高了推理速度。我们可能会发现一个问题,如果在浅层的准确率很低,后面也就没办法了。所以个人感觉这个模型效果好的原因应该是基于一个假设:“数据集中大部分样本都是简单样本”。这里我有点想知道,如果只用 teacher 来整理一个最底层的 student,不知道效果会咋样。
三、实验
作者在 12 个数据集上对模型效果进行了对比,除了传统的 BERT,也对比了 DistilBERT。结果如下:
从实验结果中可以看出来,FastBERT 在提升速度的同时对于精度的减少还是比较小的。但是这个实验美中不足的是在一些数据集上 FASTBERT 和 DistilBERT 并没有压缩到同等的数量级,这在一定程度上没法真实的比较两个模型的效果。
作者通过实验证明了自蒸馏的效果,结果如下:
通过上图可以发现,加入自蒸馏后,计算复杂度下降了非常多,但是 Acc 几乎没有下降。同时文中提到的假设:不确定性低,准确率高。作者也通过实验证明了这一假设:
同时作者也给出了每层不确定性分布的实验,结果如下:
通过这个实验结果结合不确定性低准确率高的假设,我们就可以发现,其实很多样本在浅层就有很低的不确定性,所以完全没有必要继续走到后面的层去进行分类。
三、结论
作者通过充分利用 BERT 每层的输出,发现很多简单的样本其实在浅层就已经可以很好的分类从而不需要走到最后,极大地提高了 BERT 的推理速度。文章的思路还是非常的新奇的,不是传统的减少模型参数量来提高速度的方法,而是将训练中的 early stop 思想用到了推理上。同时自蒸馏的想法也很有启发性,避免了传统蒸馏方式需要自己设计 student model 的情况。