[ACL 2022] BERT Learns to Teach: Knowledge Distillation with Meta Learning

[ACL 2022] BERT Learns to Teach: Knowledge Distillation with Meta Learning

Motivation

之前的工作通常训练一个大模型作为teacher,然后保持teacher不动训练student模型来完成teacher模型知识的迁移。然而这种做法有以下缺点:1)teacher不了解student的能力。近来有一些工作通过联合训练teacher和student来引入student-aware的蒸馏,但是这种方法也有提升的空间,因为2)teacher不是为了蒸馏优化的。之前的工作中,teacher通常都是为了自己的推理性能而优化的。然而,teacher不能感知到将其知识迁移给student的需求,这样不是最优的。为了解决这两个问题,我们提出了MetaDistil,一个新的使用meta-learning的蒸馏框架来利用学习过程中student的反馈来提升teacher的知识迁移能力。

Approach (Knowledge Distillation with Meta Learning)

Background

Knowledge Distillation

给定teacher T T T 和student S S S,以及有标注的训练数据集 D = { ( x 1 , y 1 ) , . . . , ( x N , y N ) } \mathcal{D} = \{(x_1,y_1), ..., (x_N, y_N)\} D={(x1,y1),...,(xN,yN)},student的损失函数可以写成下面形式:
L S ( D ; θ S ; θ T ) = 1 N ∑ i = 1 N [ α L τ ( y i , S ( x i ; θ S ) ) + ( 1 − α ) L K D ( T ( x i ; θ T ) , S ( x i ; θ S ) ) ] \begin{equation} \mathcal{L}_S(\mathcal{D};\theta_S;\theta_{T}) =\frac{1}{N}\sum_{i=1}^N[\alpha\mathcal{L}_{\tau}(y_i,S(x_i;\theta_S))\\ +(1-\alpha)\mathcal{L}_{KD}(T(x_i;\theta_T),S(x_i;\theta_S))] \end{equation} LS(D;θS;θT)=N1i=1N[αLτ(yi,S(xi;θS))+(1α)LKD(T(xi;θT),S(xi;θS))]
L τ \mathcal{L}_{\tau} Lτ是task-specific的loss, L K D \mathcal{L}_{KD} LKD是衡量teacher和student之间相似度的蒸馏损失,如teacher和student 概率分布的KL散度、logits的MSE等等。(此前有研究发现MSE比KL散度更稳定且效果略优,因此本文采用MSE)

Meta Learning

在涉及bi-level优化的meta learning算法中,存在inner-learner f i f_i fi和meta-learner f m f_m fm。inner-learner在meta-learner的帮助下被训练来完成任务 τ \tau τ f i f_i fi f m f_m fm 的帮助下进行训练的过程被叫做inner-loop,在inner-loop后更新的inner-learner记为 f i ′ ( f m ) f_i^{\prime}(f_m) fi(fm)。而meta-leaner通过meta目标优化,通常是最大化inner-loop之后inner-learner的预期性能。这个学习过程被叫做meta-loop,通常通过更新后的inner-learner在一些验证集上的损失 L ( f i ′ ( f m ) ) \mathcal{L}(f_i^{\prime}(f_m)) L(fi(fm))的导数进行梯度下降完成。

Methodology

在这里插入图片描述

Pilot Updata

本来的meta learning目的是学习一个好的meta-learner,能够泛化到不同的inner-learner来解决不同的任务。而在MetaDistil中,我们只关注inner-learner的表现。为了同步meta-和inner-learner的学习进度,我们设计了一个pilot update机制。
对一个batch的训练数据 x x x,我们复制了inner-learner f i f_i fi,然后在 x x x上更新复制的 f i ′ f_i^{\prime} fi和meta-learner f m f_m fm。然后。我们丢弃 f i ′ f_i^{\prime} fi并使用更新了的 f m f_m fm在同一份数据 x x x上再次更新 f i f_i fi。这个机制可以使用数据 x x x来同时影响 f i f_i fi f m f_m fm,因此对齐了训练过程。

Learning to Teach

在这里插入图片描述
在MetaDistil架构中,teacher的优化目标是:teacher蒸馏之后student的性能。student网络 θ S \theta_{S} θS 是inner-learner,teacher网络 θ T \theta_{T} θT 是meta-learner。在每个训练step中,我们首先将 θ S \theta_{S} θS 复制一份成 θ S ′ \theta_{S}^{\prime} θS。然后给一个batch数据以及学习率 λ \lambda λ θ S ′ \theta_{S}^{\prime} θS 通过传统的蒸馏算法更新:
θ S ′ ( θ T ) = θ S − λ ∇ θ S L S ( x ; θ S ; θ T ) \begin{equation} \theta_{S}^{\prime}(\theta_{T})=\theta_{S}-\lambda\nabla_{\theta_{S}}\mathcal{L}_S(x;\theta_{S};\theta_{T}) \end{equation} θS(θT)=θSλθSLS(x;θS;θT)
然后我们观察实验的student参数 θ S ′ \theta_{S}^{\prime} θS 和 student 在从quiz set上采样的一个batch数据的quiz loss l q = L τ ( q , θ S ′ ( θ T ) ) l_q=\mathcal{L}_{\tau}(q,\theta_{S}^{\prime}(\theta_T)) lq=Lτ(q,θS(θT)),它是一个关于teacher参数 θ T \theta_{T} θT 的函数。因此,我们可以对 θ T \theta_{T} θT 进行梯度下降来优化 l q l_q lq
θ T ← θ T − μ ∇ θ T L τ ( q ; θ S ′ ( θ T ) ) \begin{equation} \theta_T\larr \theta_T - \mu\nabla_{\theta_{T}}\mathcal{L}_{\tau}(q;\theta_{S}^{\prime}(\theta_T)) \end{equation} θTθTμθTLτ(q;θS(θT))
在我们的实验中,student没有在quiz set上训练,teacher只在quiz set上进行meta-update。我们也没有使用动态quiz set 策略,否则student将会在quiz set上训练过,这样loss就不会包含足够的信息了。在mate-update teacher之后,我们使用公式(2) 更新本来的student。使用公式(3) 优化teacher参数 θ T \theta_{T} θT是最大化被teacher蒸馏后的student期望表现。meta-objective 允许teacher模型调整参数来更好地将知识蒸馏给student。

Experiments

Settings

目标:BERT-base蒸馏6层的BERT,使用模型logits间的MSE loss作为蒸馏目标。在这里插入图片描述

Analysis

Why Does MetaDistill Work?

对于87%的更新,在进行真实更新后student的dev loss要小于没有进行pilot update的。这就说明了pilot update机制能更好地匹配student和teacher模型。
而且,我们发现前一半更新中,91%的更新让teacher与student更相似了(从logits的kl散度来看),这就表明teacher学着去适应低性能的student。在后一半更新中,这个百分比降到了63%。我们认为后面的训练过程中,teacher需要去超越student去自我进化来引导student进一步提升。

Hyper-parameter Sensitivity

在这里插入图片描述

Limitation

在这里插入图片描述
尽管MetaDistil达到最好性能的时间更久,但是它达到与PKD最好性能持平的时间基本相同。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值