模型蒸馏探索(Bert)

1. 蒸馏是什么?

​  所谓的蒸馏,指的是从大模型(通常称为teacher model)中学习小模型(通常称为student model)。何以用这个名字呢?在化学中,蒸馏是一个有效的分离沸点不同的组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的。那么,从大模型中,通过一定的技术手段,将原模型中的知识提取出来,这个过程很类似于物质分离,所以将其称为是蒸馏。

2. 蒸馏方法

2.1 Logit Distillation

 深度学习巨头Hinton提出,是一篇开创性的工作。

 其改进是针对softmax进行的改进:

q i = e x p ( z i / T ) ∑ j e x p ( z j / T ) q_{i}=\frac {exp(z_{i}/T)}{\sum_{j}exp(z_{j}/T)} qi=jexp(zj/T)exp(zi/T)

 其中的T是temperature,为设定的超参数。

计算流程

 最终的loss为:

L = ( 1 − α ) c r o s s _ e n t r o p h y ( y , p ) + α ∗ c r o s s _ e n t r o p h y ( q , p ) T 2 y : 真 实 l a b e l p : s t u d e n t   m o d e l 预 测 结 果 q : t e a c h e r   m o d e l 预 测 结 果 α : 蒸 馏 l o s s 权 重 因 为 求 梯 度 的 时 候 会 新 的 目 标 函 数 会 导 致 梯 度 是 以 前 的 1 T 2 , 所 以 要 再 乘 上 T 2 L=(1-\alpha)cross\_entrophy(y,p)+\alpha *cross\_entrophy(q,p)T^{2} \\ y:真实label\\p:student\ model预测结果\\q:teacher\ model预测结果 \\ \alpha:蒸馏loss权重 \\因为求梯度的时候会新的目标函数会导致梯度是以前的\frac {1}{T^2},所以要再乘上T^{2} L=(1α)cross_entrophy(y,p)+αcross_entrophy(q,p)T2ylabelp:student modelq:teacher modelα:lossT21T2

 这个改进的motivation有一下几点:

  • softmax函数自身对数据分布敏感

 对于相同的logits,当采用不同的temperature的时候,softmax之后的分布变化较大,温度越大,分布越平缓,结果的区分度越低,相当于增大了学习的难度,以后做inference的时候,temperature=1,分类结果会得到较好的提升。

  • soft prediction本身带有额外的信息

soft prediction代表teacher model对不同类别的识别概率,这个概率分布本身就带有一定的信息的,比如预测轿车的时候,识别为垃圾车和胡萝卜的概率可能都比较低,但是识别为垃圾车的概率显然要比识别为胡萝卜更高,这个信息说明垃圾车本身相比于胡萝卜与轿车的相关性更高。

 这里有人可能会好奇,为何需要先训练teacher model,然后再蒸馏到student model上面?为何不能直接训练student model?

 要注意的是,蒸馏的核心思想是好的模型不是为了拟合训练数据,而是学习如何泛化到新的数据,所以蒸馏到目的是为了让学生模型学习到教师模型的泛化能力。单纯训练学生模型的话,因为模型比较简单,所以训练难度也更大,其训练出的模型的泛化能力大概率也不如教师模型强大。

 另外注意,模型蒸馏是一种思想,理解了这篇文章的思想,可以泛化到后续的许多模型中去,因为蒸馏的使用其实本质就是各种loss function的设计。

2.2 Distilled BiLSTM

 这篇文章在性能方面完全不存在竞争力,在transformer满天飞的年代,其蒸馏的结果仅仅是获得了ELMo级别的性能,不过,这篇文章最大的亮点是,在ELMo性能级别下,其使用的参数少了大约100倍,推理时间少了15倍,这对于资源敏感类任务来说可谓是一个巨大的诱惑。

Distilled BiLSTM两类任务,其teacher model使用的是BERT-large,student model为BiLSTM+Relu

 注意FIgure 2中的d操作,对于两个句子向量,其操作为: f ( h s 1 , h s 2 ) = [ h s 1 , h s 2 , h s 1 ⊙ h s 2 , ∥ h s 1 − h s 2 ∥ ] , ⊙ f(h_{s1},h_{s2})=[h_{s1},h_{s2},h_{s1}\odot h_{s2},\|h_{s1}-h_{s2}\|],\odot f(hs1,hs2)=[hs1,hs2,hs1hs2,hs1hs2],代表 elementwise multiplication.

 上损失函数:

KaTeX parse error: No such environment: equation at position 8: \begin{̲e̲q̲u̲a̲t̲i̲o̲n̲}̲\begin{aligned}…

 其中, z B , z S z^{B},z^{S} zB,zS分别为teacherstudentlogits,即预测值, t i t_{i} ti为真实one-hot类别向量 t t t为第i个元素,对于无标签元素, t i = 1   i f   i = a r g m a x y B   e l s e   0 t_{i}=1\ if\ i=argmaxy^{B}\ else\ 0 ti=1 if i=argmaxyB else 0

 论文中作者还提出了一些nlp领域的数据增强技术,可以看原文了解一下。

2.3 DistilBERT

这篇文章没啥难理解的地方,记录一下就行了。

 效果:模型尺寸降低40%,保留97%的泛化能力,提升了60%的速度。

 模型:teacher model为标准的Bertstudent modellayers=teacher model layers/2Bert,从teacher modellayers中每隔2层取一层初始化student modellayer

 损失函数:公式(1)的cross entropymasked language modeling loss,外加两模型的首层的隐状态的cos loss

2.4 BERT-PKD

PKD(patient knowledge distill),其teacher model为标准BERT,而student model也是BERT,不过其堆叠的层数要少于teacher。先上图为敬:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-LrC5KmQ9-1632812843154)()]

两种架构.(左)PKD-skip:每2层学习一下teacher的输出.(右)PKD-last:学习teacher网络的最后六层输出。两种结果都有注意一个细节,就是最后的一层输出不参与PT loss,而是参与DS loss

 从架构图可以看出,相比于直接学习最终的输出,PKD方法还教导student model学习中间层的输出,last方法的先验假设是认为teacher modeltop layers包含最丰富的信息以便指导student model,而skip的先验假设则是认为teacher modellower layers也包含了需要被蒸馏的重要信息,从作者的结果来看,PKD-Skip 效果slightly better,作者认为PKD-Skip抓住了老师网络不同层的多样性信息。而PKD-Last抓住的更多相对来说同质化信息,因为集中在了最后几层。

 对于BERT类模型来说,由于其输入序列长度比较大,如果学习所有的tokens,不仅computationally expensive, 也可能introduce noise,又考虑到BERT的预测是只针对*“[CLS]" token的最后一层输出,所以如果student model可以获得teacher model[CLS]的表达能力,那么它就有了teacher model的泛化能力,所以直接学习[CLS]*。

 损失函数设计:

L P K D = ( 1 − α ) L C E s + α L D S + β L P T 上 标 s 代 表 s t u d e n t   m o d e l L_{PKD}=(1-\alpha)L_{CE}^{s}+\alpha L_{DS}+\beta L_{PT}\quad 上标s代表student\ model L

  • 3
    点赞
  • 38
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值