DeFormer: Decomposing Pre-trained Transformers for Faster Question Answering
https://github.com/StonyBrookNLP/deformer
Motivation
单塔模型运行速度慢,并且内存密集,本文引入DeFormer,它在较低的层单独处理question和passages,这允许预先计算段落表示,从而大大减少运行时计算。
1.介绍
根据这个事实:
- 预训练模型较低的层往往侧重于local现象,如句法方面,而较高的层次侧重于global现象,如与目标任务相关的语义方面。
2.在一个标准的基于BERT的问答模型中,当改变问题时,文本的低层表示中的差异较小。这意味着在较低的层中,来自问题的信息对于形成文本表示并不重要。
综上所述,这些都表明在Transformer的较低层只考虑局部环境,在较高层考虑全局环境,在效率方面可以以非常小的代价提供加速。因此出现了本文的DeFormer(较低层独立处理问题和上下文文本,较高层共同处理它们):
假设允许
n
n
n层模型中的前
k
k
k个层独立处理问题和上下文文本,并缓存第
k
k
k层的输出。在运行时,首先通过模型的前
k
k
k层处理问题,第
k
k
k层的文本表示从缓存中加载。这两个第
k
k
k层表示作为输入被馈送到第(
k
k
k + 1)层,并且进一步的处理像在原始模型中一样通过更高层继续。
通过从原始模型中学习,可以进一步降低这种精度损失。但是原文希望Transformer的行为更像原始模型。具体来说,Transformer的上层应该生成与原始模型中相应层捕获相同类型信息的表示。为此增加了两个类似蒸馏的辅助损失,这使得解耦后的模型和原始模型之间的输出级和层级差异最小化。
2 Decomposing Transformers for Faster Inference
为了评估问题token在较低层对形成文本表示的影响,本文测量了文本表达在与不同问题配对时的变化。特别是,当与不同的问题配对时,计算平均段落表示方差。使用通过向量和它们的质心之间的余弦距离来测量方差。如下图所示:
在较低的层中,文本表示不像在较高的层中那样变化很大,这表明忽略较低层中问题标记的注意力可能不是一个坏主意。
2.1 DeFormer
对于句子对任务,文本
T
a
T_a
Ta,
T
b
T_b
Tb,token表示:
T
a
T_a
Ta:
A
A
A =[
a
1
a_1
a1;
a
2
a_2
a2;…;
a
q
a_q
aq]
T
b
T_b
Tb:
B
B
B = [
b
1
b_1
b1;
b
2
b_2
b2;…;
b
p
b_p
bp]
对于有
n
n
n层(
L
i
L_i
Li表示第
i
i
i层)的Transformer:
-
cross: X X X=[ A A A; B B B], X X X l ^l l + ^+ + 1 ^1 1= L L L i _i i( X l X^l Xl)
-
DeFormer(前k层独立处理 A A A和 B B B,后几次联合处理):
复杂度分析: O O O(( p p p+ q q q) 2 ^2 2) 到 O O O( q 2 q^2 q2+ c c c), c c c表示加载缓存的成本。
2.2 Auxiliary Supervision for DeFormer
由于DeFormer保留了大部分原始结构,因此可以用原始Transformer的预训练权重初始化这个模型,并直接在下游任务上进行微调。DeFormer在较低层的表示中丢失了一些信息。虽然上层可以学习在微调过程中对此进行补偿。但是可以进一步改进:使用原始模型行为作为额外的监督来源。
具体来说:首先用预训练的Transformer的参数初始化DeFormer的参数,并在下游任务中对其进行微调。并添加了辅助loss,使DeFormer预测与上层表示更接近完整的预训练Transformer的预测和相应的层表示。
Knowledge Distillation Loss:
我们希望DeFormer的预测分布更接近完整的预训练Transformer的预测分布。通过最小化DeFormer预测分布
P
A
P_A
PA和完整的预训练Transformer的预测分布
P
B
P_B
PB之间的
K
K
K
L
L
L散度:
Layerwise Representation Similarity Loss:
我们希望DeFormer的上层表示更接近完整的预训练Transformer。通过最小化DeFormer和完整的预训练Transformer的上层的token表示之间的欧几里得距离。让
v
v
v
i
_i
i
j
^j
j表示在全Transformer的第
i
i
i层中第
j
j
j个token,让
u
u
u
i
_i
i
j
_j
j表示相应的DeFormer表示:
对于上面
k
k
k + 1到
n
n
n中的每一层,计算层表示相似性(
l
l
l
r
r
r
s
s
s)损失如下:
最后将知识蒸馏损失(
L
L
L
k
_k
k
d
_d
d)和层表示相似性损失(
L
L
L
l
_l
l
r
_r
r
s
_s
s)与特定任务监督损失(
L
L
L
t
_t
t
s
_s
s)相加,并通过超参数调整学习它们的相对重要性:
3 Evaluation
3.1 Datasets
预训练Transformer:BERT-uncase-base 和 large
SQuAD v1.1(QA)
RACE(QA)
BoolQ(QA)
MNLI(sentence pairs)
QQP(sentence pairs)
3.3 Results
下表显示了当使用九个下层和三个上层时的结果:
BERT、BERT-large和DeFormer-BERT-large的性能、速度和内存:
推理时间分析
BERT-base与DeFormer-BERT-base的SQuAD数据集上的推理延迟(以秒为单位),作为在批处理模式下测量的平均值。在GPU和CPU上批量大小为32,在手机上(用*)批量大小为1。
消融
蒸馏损失和表示相似度损失影响:
分离层
k
k
k选择的影响(SQuAD上):
base-bert:3层后F1开始下降
large-bert:13层后F1开始下降,并还有一段上升趋势。
下图显示了问题和文章在不同层的平均距离。两个模型的文章和问题的较低层表示保持相似,但较高层表示有显著差异,这支持了缺乏交叉关注在较低层比在较高层影响更小的观点:
加上辅助loss后平均欧式距离在高层后有明显改进。