TinyBERT-模型蒸馏

本文介绍下TinyBERT,华为在2020发布的一篇论文,主要内容是对模型进行蒸馏,蒸馏的方法值得学习

论文信息


论文地址:

https://arxiv.org/abs/1909.10351

代码地址:

https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT

主要内容


目前已经有很多模型压缩的技术,如矩阵分解,量化、权重共享、剪枝、以及知识蒸馏,本文的重点在于知识蒸馏。

如上图所示,tinybert的蒸馏步骤可以概括为 General DistillationTask-specific Distillation,也就是在大规模语料上对通识知识的蒸馏,这是在预训练阶段的蒸馏,和在指定任务数据上对特定任务的知识蒸馏,并且用通识知识的蒸馏模型对指定任务的蒸馏模型进行初始化,这是在微调阶段的蒸馏。同时,在特征任务上进行知识蒸馏时,会先对数据进行增强。

作者的实验结果是,4层的tinybert可以达到bertbase的96.8%的效果,但是参数量为bertbase的13.3%,推理时间为10.6%,并且比其他蒸馏的效果要好,同时,6层的tinybert和bertbase的表现近似。

1、transformer distillation

对transformer网络层数的蒸馏。假设学生模型有M层,老师模型有N层,自定义一个map函数 n = g ( m ) n=g(m) n=g(m),实现学生层到老师层的map,表示学生模型的第m层从老师模型的第g(m)层学得信息。损失函数如下:
L m o d e l = ∑ x ∈ X ∑ m = 0 M + 1 λ m L l a y e r ( f m S ( x ) , f g ( m ) T ( x ) ) L_{model}=\sum_{x\in_{X}}\sum_{m=0}^{M+1}\lambda_mL_{layer}(f_m^S(x),f_{g(m)}^T(x)) Lmodel=xXm=0M+1λmLlayer(fmS(x),fg(m)T(x))

L l a y e r L_{layer} Llayer表示的是某一个transformer layer或者是embedding layer的损失函数, f m ( x ) f_m(x) fm(x)表示第m层的目标函数值, λ m \lambda_m λm表示第m层的重要性,为超参数。

transformer distillation包括attention distill、hidden distill、embedding distill、以及prediction distill,如下图所示:

  • attention

其中,attention distill的目标函数为:

L a t t n = 1 h ∑ i = 1 h M S E ( A i S , A i T ) L_{attn}=\frac{1}{h}\sum_{i=1}^hMSE(A_i^S,A_i^T) Lattn=h1i=1hMSE(AiS,AiT)

h表示注意力头的个数, A i A_i Ai表示学生或老师第i个注意力头的attention matrix

同时,作者表明,之所以使用 A i A_i Ai,而不是 s o f t m a x ( A i ) softmax(A_i) softmax(Ai)作为拟合目标,是因为前者的收敛更快,效果更好。

  • hidden

其中,transformer输出蒸馏的损失函数为:

L h i d n = M S E ( H S W h , H T ) L_{hidn}=MSE(H^SW_h,H^T) Lhidn=MSE(HSWh,HT)

其中 H S ∈ R l × d ∗ H^S\in{R^{l \times d^*}} HSRl×d, H T ∈ R l × d H^T\in{R^{l \times d}} HTRl×d
d ∗ d^* d表示学生模型的向量维度。 W h W_h Wh是一个可学习矩阵,用来对学生模型进行线性变化,将其转化为与老师模型相同的维度。

  • embedding

embedding层输出蒸馏的损失函数为:

L e m b d = M S E ( E S W e , E T ) L_{embd}=MSE(E^SW_e,E^T) Lembd=MSE(ESWe,ET)

可以看到基本与transformer输出的蒸馏形式是一样的。

  • prediction
    损失函数为:

L p r e d = C E ( z T t , z S t ) L_{pred}=CE(\frac{z^T}{t},\frac{z^S}{t}) Lpred=CE(tzT,tzS)

z表示logits,t表示温度系数,作者实验发现,t=1时效果最好。这部分的损失函数就和distillbert设计的蒸馏损失比较像.

整体模型的损失函数如下:
L l a y e r = { L e m b d m = 0 L h i d n + L a t t n m ∈ ( 0 , M ] L p r e d m = M + 1 L_{layer}=\begin{cases} L_{embd} & m=0 \\ L_{hidn}+L_{attn} & m\in(0,M] \\ L_{pred} & m=M+1 \end{cases} Llayer=LembdLhidn+LattnLpredm=0m(0,M]m=M+1

其中,m表示学生的层数

2、task-specific distillation

该部分先对数据集进行增强,然后进行蒸馏。作者对数据增强的解释为,学生模型在经过增强的数据集上进行训练,可以提高其效果,也就是说,相比于老师模型,学生模型在特定任务上的训练数据是经过增强的,以此来提升学生模型的效果,因此学生就有超过老师的可能。

数据增强
其伪代码如下:

作者结合bert和glove的词嵌入,在word-level上进行替换,以实现数据增强。作者的参数设置如下, p t = 0.4 p_t=0.4 pt=0.4 N a = 20 N_a=20 Na=20 K = 15 K=15 K=15

论文并没有对task-specific distillation的蒸馏部分进行阐述,说明其与general distill的蒸馏方式应该是一样的,只是一个处于预训练阶段,一个处于微调阶段。

3、实验结果

实验时,作者使用 g ( m ) = 3 × m g(m)=3 \times m g(m)=3×m进行映射,也就是说4层的tinybert的每层都是从3层的bertbase中学得。

下面是作者对tinybert使用得学习策略和蒸馏方式做的消融实验:

下面是作者针对学生层到老师层的映射做的消融实验:

可以看到,使用均匀映射的效果是最好的,同时,作者也表明,对于一个下游任务,自适应的选择层数是一个具有挑战性的问题,也是未来的工作方向。

相关思考


该论文通过蒸馏方式实现对模型的压缩,整体上的实现分为以下几步:
  • 预训练阶段的蒸馏
  • 数据增强
  • 微调阶段的蒸馏

每个蒸馏,又会进行以下操作:

  • embedding层的蒸馏
  • hidden层的蒸馏
  • attention的蒸馏
  • 预测层的蒸馏

之所以做数据增强,是为了在对具体任务蒸馏时,扩充学生模型的训练集,提高学生模型的表现。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值