[ACL 2022] BERT Learns to Teach: Knowledge Distillation with Meta Learning
Motivation
之前的工作通常训练一个大模型作为teacher,然后保持teacher不动训练student模型来完成teacher模型知识的迁移。然而这种做法有以下缺点:1)teacher不了解student的能力。近来有一些工作通过联合训练teacher和student来引入student-aware的蒸馏,但是这种方法也有提升的空间,因为2)teacher不是为了蒸馏优化的。之前的工作中,teacher通常都是为了自己的推理性能而优化的。然而,teacher不能感知到将其知识迁移给student的需求,这样不是最优的。为了解决这两个问题,我们提出了MetaDistil,一个新的使用meta-learning的蒸馏框架来利用学习过程中student的反馈来提升teacher的知识迁移能力。
Approach (Knowledge Distillation with Meta Learning)
Background
Knowledge Distillation
给定teacher
T
T
T 和student
S
S
S,以及有标注的训练数据集
D
=
{
(
x
1
,
y
1
)
,
.
.
.
,
(
x
N
,
y
N
)
}
\mathcal{D} = \{(x_1,y_1), ..., (x_N, y_N)\}
D={(x1,y1),...,(xN,yN)},student的损失函数可以写成下面形式:
L
S
(
D
;
θ
S
;
θ
T
)
=
1
N
∑
i
=
1
N
[
α
L
τ
(
y
i
,
S
(
x
i
;
θ
S
)
)
+
(
1
−
α
)
L
K
D
(
T
(
x
i
;
θ
T
)
,
S
(
x
i
;
θ
S
)
)
]
\begin{equation} \mathcal{L}_S(\mathcal{D};\theta_S;\theta_{T}) =\frac{1}{N}\sum_{i=1}^N[\alpha\mathcal{L}_{\tau}(y_i,S(x_i;\theta_S))\\ +(1-\alpha)\mathcal{L}_{KD}(T(x_i;\theta_T),S(x_i;\theta_S))] \end{equation}
LS(D;θS;θT)=N1i=1∑N[αLτ(yi,S(xi;θS))+(1−α)LKD(T(xi;θT),S(xi;θS))]
L
τ
\mathcal{L}_{\tau}
Lτ是task-specific的loss,
L
K
D
\mathcal{L}_{KD}
LKD是衡量teacher和student之间相似度的蒸馏损失,如teacher和student 概率分布的KL散度、logits的MSE等等。(此前有研究发现MSE比KL散度更稳定且效果略优,因此本文采用MSE)
Meta Learning
在涉及bi-level优化的meta learning算法中,存在inner-learner f i f_i fi和meta-learner f m f_m fm。inner-learner在meta-learner的帮助下被训练来完成任务 τ \tau τ。 f i f_i fi 在 f m f_m fm 的帮助下进行训练的过程被叫做inner-loop,在inner-loop后更新的inner-learner记为 f i ′ ( f m ) f_i^{\prime}(f_m) fi′(fm)。而meta-leaner通过meta目标优化,通常是最大化inner-loop之后inner-learner的预期性能。这个学习过程被叫做meta-loop,通常通过更新后的inner-learner在一些验证集上的损失 L ( f i ′ ( f m ) ) \mathcal{L}(f_i^{\prime}(f_m)) L(fi′(fm))的导数进行梯度下降完成。
Methodology
Pilot Updata
本来的meta learning目的是学习一个好的meta-learner,能够泛化到不同的inner-learner来解决不同的任务。而在MetaDistil中,我们只关注inner-learner的表现。为了同步meta-和inner-learner的学习进度,我们设计了一个pilot update机制。
对一个batch的训练数据
x
x
x,我们复制了inner-learner
f
i
f_i
fi,然后在
x
x
x上更新复制的
f
i
′
f_i^{\prime}
fi′和meta-learner
f
m
f_m
fm。然后。我们丢弃
f
i
′
f_i^{\prime}
fi′并使用更新了的
f
m
f_m
fm在同一份数据
x
x
x上再次更新
f
i
f_i
fi。这个机制可以使用数据
x
x
x来同时影响
f
i
f_i
fi和
f
m
f_m
fm,因此对齐了训练过程。
Learning to Teach
在MetaDistil架构中,teacher的优化目标是:teacher蒸馏之后student的性能。student网络
θ
S
\theta_{S}
θS 是inner-learner,teacher网络
θ
T
\theta_{T}
θT 是meta-learner。在每个训练step中,我们首先将
θ
S
\theta_{S}
θS 复制一份成
θ
S
′
\theta_{S}^{\prime}
θS′。然后给一个batch数据以及学习率
λ
\lambda
λ,
θ
S
′
\theta_{S}^{\prime}
θS′ 通过传统的蒸馏算法更新:
θ
S
′
(
θ
T
)
=
θ
S
−
λ
∇
θ
S
L
S
(
x
;
θ
S
;
θ
T
)
\begin{equation} \theta_{S}^{\prime}(\theta_{T})=\theta_{S}-\lambda\nabla_{\theta_{S}}\mathcal{L}_S(x;\theta_{S};\theta_{T}) \end{equation}
θS′(θT)=θS−λ∇θSLS(x;θS;θT)
然后我们观察实验的student参数
θ
S
′
\theta_{S}^{\prime}
θS′ 和 student 在从quiz set上采样的一个batch数据的quiz loss
l
q
=
L
τ
(
q
,
θ
S
′
(
θ
T
)
)
l_q=\mathcal{L}_{\tau}(q,\theta_{S}^{\prime}(\theta_T))
lq=Lτ(q,θS′(θT)),它是一个关于teacher参数
θ
T
\theta_{T}
θT 的函数。因此,我们可以对
θ
T
\theta_{T}
θT 进行梯度下降来优化
l
q
l_q
lq:
θ
T
←
θ
T
−
μ
∇
θ
T
L
τ
(
q
;
θ
S
′
(
θ
T
)
)
\begin{equation} \theta_T\larr \theta_T - \mu\nabla_{\theta_{T}}\mathcal{L}_{\tau}(q;\theta_{S}^{\prime}(\theta_T)) \end{equation}
θT←θT−μ∇θTLτ(q;θS′(θT))
在我们的实验中,student没有在quiz set上训练,teacher只在quiz set上进行meta-update。我们也没有使用动态quiz set 策略,否则student将会在quiz set上训练过,这样loss就不会包含足够的信息了。在mate-update teacher之后,我们使用公式(2) 更新本来的student。使用公式(3) 优化teacher参数
θ
T
\theta_{T}
θT是最大化被teacher蒸馏后的student期望表现。meta-objective 允许teacher模型调整参数来更好地将知识蒸馏给student。
Experiments
Settings
目标:BERT-base蒸馏6层的BERT,使用模型logits间的MSE loss作为蒸馏目标。
Analysis
Why Does MetaDistill Work?
对于87%的更新,在进行真实更新后student的dev loss要小于没有进行pilot update的。这就说明了pilot update机制能更好地匹配student和teacher模型。
而且,我们发现前一半更新中,91%的更新让teacher与student更相似了(从logits的kl散度来看),这就表明teacher学着去适应低性能的student。在后一半更新中,这个百分比降到了63%。我们认为后面的训练过程中,teacher需要去超越student去自我进化来引导student进一步提升。
Hyper-parameter Sensitivity
Limitation
尽管MetaDistil达到最好性能的时间更久,但是它达到与PKD最好性能持平的时间基本相同。