知识蒸馏简单介绍

知识蒸馏简单介绍

在自然语言领域中,自bert伊始,预训练模型变得越来越大。大的预训练模型在离线测评时通常有着良好的结果,但是对于线上的时效要求往往难以满足。因此,发展出了权重剪枝(Weight Pruning),量化(Quantization),知识蒸馏(Knowledge Distillation)这些技术来减少模型的大小,提高模型的运行效率。

本文将从模型训练和损失函数两个角度去简单介绍下几种不同的知识蒸馏方法。

知识蒸馏是指将知识从一个复杂的模型压缩进入一个更小的模型的方法。

损失函数

logits

在看Distilling the Knowledge in a Neural Network时,论文中指出Caruana在Model compression对于模型压缩使用集成模型的softmax的logits,使用均方误差进行小模型学习,通过这种方式避免hard targets带来的信息损失。

Caruana and his collaborators circumvent this problem by using the logits (the inputs to the final softmax) rather than the probabilities produced by the softmax as the targets for learning the small model and they minimize the squared difference between the logits produced by the cumbersome model and the logits produced by the small model.

可能因为阅读不仔细,笔者在阅读Caruana的Model compression中并未发现相关内容的描述。在Model compression中虽然标题在讲模型压缩,但主要内容是在讲如何生成伪数据用于集成模型打标,并在论文中提出了MUNGE方法进行伪数据生成,使用这些伪数据,可以训练得到超过原有神经网络模型(论文中使用ann作为小模型学习集成模型的知识)的效果。

softmax with temperature

Distilling the Knowledge in a Neural Network中使用了带有temperature的softmax,因为文中认为在大模型中,会对正确答案产生很高的置信度,其他错误类别的概率很低,这样会导致中间信息的缺失。因为非正确类别的值特别小,接近于0.

本文中使用带有温度的softmax来对原生的softmax函数进行平滑,其中T表示温度(temperature),通常设置为1,T设置为较高的值将有助于在类间得到更平滑的概率分布。
q i = e x p ( z i / T ) ∑ j e x p ( z j / T ) q_i=\frac{exp(z_i/T)}{\sum_jexp(z_j/T)} qi=jexp(zj/T)exp(zi/T)
如果用于训练蒸馏模型的数据集正确标签已知,那么文中使用下面的损失函数进行模型训练:
L K D ( W S ) = α H ( y t r u e , P S ) + ( 1 − α ) H ( P T τ , P S τ ) L_{KD}(W_S) = \alpha H(y_{true},P_S) + (1-\alpha) H(P^τ_T,P^τ_S) LKD(WS)=αH(ytrue,PS)+(1α)H(PTτ,PSτ)
上式中 H H H表示交叉熵函数, P S P_S PS表示学生的soft label,温度参数为1, P T τ P^τ_T PTτ P S τ P^τ_S PSτ分别表示教师/学生的soft label,同时两者有着相同的温度参数T。 W s W_s Ws为学生模型的权重, L K D L_{KD} LKD为知识蒸馏的损失函数。原文表示当 α < 1 − α \alpha<1-\alpha α<1α时会取得最优的结果。

**注意:**soft label的损失计算时梯度缩小了 T 2 T^2 T2倍,所以在同时计算hard label和soft label时需要对soft label乘 T 2 T^2 T2,因此在设置(1- α \alpha α)中需要注意这个问题。

softmax with KL divergence

Knowledge Distillation via Route Constrained Optimization中KD损失函数中的教师模型的softmax和学生模型的softmax的使用KL散度(KullbackLeibler divergence)计算。
L K D = H ( y , P s ) + λ K L ( P t τ , P s τ ) L_{KD} = H(y, P_s) + λKL(P^τ_t, P^τ_s) LKD=H(y,Ps)+λKL(Ptτ,Psτ)

hint loss

FITNETS: HINTS FOR THIN DEEP NETS中使用两阶段的训练方式训练学生模型,同时引入了新的训练目标hint loss。hint loss具体如下:
L H T ( W G u i d e d , W r ) = 1 2 ∣ ∣ u h ( x ; W H i n t ) − r ( v g ( x ; W G u i d e d ) ; W r ) ∣ ∣ 2 L_{HT}(W_{Guided},W_r) = \frac{1}{2}||u_h(x;W_{Hint}) − r(v_g(x;W_{Guided});W_r)||^2 LHT(WGuided,Wr)=21∣∣uh(x;WHint)r(vg(x;WGuided);Wr)2

其中 u h u_h uh v g v_g vg分别表示教师/学生模型的嵌套函数,具体为使用了参数 W H i n t / W G u i d e W_{Hint}/W_{Guide} WHint/WGuide的Hint/Guide Layer.

Relative dissimilarity

在Learning from Multiple Teacher Networks中认为使用MSE来监督学生模型中间层学习教师模型中间层一方面由于是学习精确值,可能对学生网络过于严苛,另外一方面对有多个教师模型的情况难以适用。因此本文中提出使用相对距离的方式来监督学生模型学习。即只要学生模型对样本的相对距离判断和教师模型一致即可。

在监督学生模型时使用下面的损失函数:
L R D ( w S ; x i , x i + , x i − ) = m a x ( 0 , d ( p i , p i + ) − d ( p i , p i − ) + δ ) L_{RD}(w_S;x_i,x^+_i,x^-_i)=max(0,d(p_i,p_i^+)-d(p_i,p^-_i)+\delta) LRD(wS;xi,xi+,xi)=max(0,d(pi,pi+)d(pi,pi)+δ)
其中d表示距离矩阵,如 d S i j = d ( p i , p j ) , d T i j = d ( q i , q j ) d^{ij}_S=d(p_i,p_j),d^{ij}_T=d(q_i,q_j) dSij=d(pi,pj),dTij=d(qi,qj),q表示教师模型中间层对于输入x的向量表达,p表示学生模型中间层对于输入x的向量表达。对于一个三元输入 ( x i , x i 1 , x i 2 ) (x_i,x_{i1},x_{i2}) (xi,xi1,xi2)根据教师模型的相对距离得到学生模型的正负例标签:
p i 1 = { p i + , if   d ( q i , q i 1 ) < d ( q i , q i 2 ) p i − , o t h e r w i s e \begin{equation} p_{i1} = \begin{cases} p_i^+, &\text{if}\ \ d(q_i,q_{i1})<d(q_i,q_{i2}) \\ p_i^-, &otherwise \end{cases} \end{equation} pi1={pi+,pi,if  d(qi,qi1)<d(qi,qi2)otherwise
在论文中因为使用多个模型,所以对于相对距离的判断可能存在冲突,可以使用投票法来确定最佳的距离排列顺序。

模型训练

训练方案一:离线soft label蒸馏

在这里插入图片描述

  1. 使用标签数据训练teacher模型

  2. 使用训练好的teacher模型,在未标注或者有标注的数据上预测,得到soft label。

  3. 对学生模型使用soft label进行训练,如果训练数据存在真实标签,使用soft label和True lalel结合进行训练

训练方案二:分步训练hint layer、soft label

在这里插入图片描述

  1. 使用标签数据训练teacher模型
  2. 使用teacher的中间层去教导student的中间层,通过 L H T ( W G u i d e d , W r ) = 1 2 ∣ ∣ u h ( x ; W H i n t ) − r ( v g ( x ; W G u i d e d ) ; W r ) ∣ ∣ 2 L_{HT}(W_{Guided},W_r) = \frac{1}{2}||u_h(x;W_{Hint}) − r(v_g(x;W_{Guided});W_r)||^2 LHT(WGuided,Wr)=21∣∣uh(x;WHint)r(vg(x;WGuided);Wr)2进行优化。当student的中间层尺寸小于teacher的中间层尺寸可以对student的中间层进行线性变换,使得尺寸一致。
  3. 使用teacher的soft label去教导student的soft label,训练时训练整个student网络,使用公式 L K D ( W S ) = H ( y t r u e , P S ) + λ H ( P T τ , P S τ ) L_{KD}(W_S) = H(y_{true},P_S) + \lambda H(P^τ_T,P^τ_S) LKD(WS)=H(ytrue,PS)+λH(PTτ,PSτ),其中 λ \lambda λ会随着训练逐渐衰减,具体见[附录1](# 附录1)
训练方案三:多步锚点学习
  1. 使用标签数据训练教师模型,在训练教师模型时按照固定epochs间隔数进行模型保存,得到教师锚点模型(Knowledge Distillation via Route Constrained Optimization中使用固定epoch间隔数,个人感觉可以根据教师模型的loss进行保存,根据loss差值等差或等比的方式保存教师模型),也就是训练过程中模型的中间状态。
  2. 将学生模型使用教师模型的锚点模型进行学习,损失函数为[softmax with KL divergence](#softmax with KL divergence)。在训练时学生模型的权重保留,教师模型的权重在学生完成一次完整的训练后替换为epochs更大的锚点模型。
训练方案四:单步锚点学习

在训练方案三中因为对学生模型使用多个锚点模型分别训练,导致学生模型的训练会使用方案一中几倍的训练时间(具体倍数等于训练中所用的锚点数),针对此缺点,提出单步的锚点训练。此方案和方案三的最大区别在于方案三中为学生模型完成训练(等同于方案一中的student的步骤三)才进行教师模型的权重替换,而在此方案则在训练过程中即进行教师的权重替换。

  1. 使用标签数据训练教师模型,训练时按照固定epoch间隔数进行教师模型保存。如使用间隔数60,总训练epochs为240,得到4个锚点模型,分别为epochs为60,120,180,240时的教师模型。

  2. 初始化学生模型,使用锚点模型对训练集推断得到soft label。在学生模型进行前60epoch训练时用锚点60的教师模型进行监督学习,在学生模型进行60-120epochs训练时使用120锚点的教师模型,以此类推,最后使用240锚点的教师模型监督180-240epochs的学生模型。监督学习损失函数为[softmax with KL divergence](#softmax with KL divergence)。

训练方案五:贪婪策略的锚点学习

贪婪策略的锚点学习和方案三类似,区别在于方案三使用固定epoch间隔,而贪婪策略根据计算结果,为每次训练选取不同的epoch间隔进行学习。

在下文中 ϕ s \phi_s ϕs表示学生模型网络, ϕ t \phi_t ϕt表示教师模型网络。 X ′ X' X表示验证集, r i j r_{ij} rij表示j epochs的教师模型对于已被i epochs的教师模型指导过的学生模型的学习难度。

  1. 使用标签数据训练教师模型,训练时使用固定的epoch间隔数进行教师模型进行保存,为了保证效果,可以使用间隔数1.
  2. 初始化学生模型,使用锚点10的模型进行初始监督学习。
  3. 将学生模型 W s W_s Ws和教师锚点模型 W t i W_{t_i} Wti进行计算 r i j = H j − H i H i r_{ij}=\frac{H_j − H_i}{H_i} rij=HiHjHi,其中 H i = H ( ϕ s ( X ′ , W s ) , ϕ t ( X ′ , W t i ) ) H_i=H(\phi_s(X',W_s),\phi_t(X',W_{t_i})) Hi=H(ϕs(X,Ws),ϕt(X,Wti)), H j = H ( ϕ s ( X ′ , W s ) , ϕ t ( X ′ , W t j ) ) H_j=H(\phi_s(X',W_s),\phi_t(X',W_{t_j})) Hj=H(ϕs(X,Ws),ϕt(X,Wtj)),当 r i j r_{ij} rij的值大于阈值 δ \delta δ时,j++,然后重新计算新的j的锚点模型和学生模型的 r i j r_{ij} rij,否则返回j-1对应的锚点模型进行学生模型的训练。
  4. 重复步骤三,直到j==N,使用N的锚点模型进行学生模型的最后一次训练,训练完成后结束。
训练方案六:联合训练soft label蒸馏

本方案整体和方案四类似,不过训练时教师模型和学生模型联合在一起训练,每次训练一次教师模型,再训练一次学生模型,学生模型学习完的soft loss会再反馈给教师模型,让教师模型知道如何指导学生是合适的,并且可以微弱提升教师的性能。

  1. 将教师和学生模型放在一起进行初始化(keras中实现为定义三个模型,教师模型,学生模型,mix模型)
  2. 使用数据集的hard label训练一次教师,然后使用教师的soft label结合hard label训练一次mix模型。通过单步锚点的方式训练三个模型。
训练方案七:联合训练soft label蒸馏,loss相加

方案六是使用串行的方式联合训练教师模型和学生模型,此方案使用并行的方式联合训练教师和学生模型,通过将多个损失相加的方式进行模型训练。

  1. 初始化教师模型和学生模型
  2. 使用数据集进行教师和学生的联合训练,损失函数为 L = α H ( y t r u e , P S ) + β H ( P T τ , P S τ ) + σ H ( y t r u e , P T ) L=\alpha H(y_{true},P_S) + \beta H(P^τ_T,P^τ_S)+\sigma H(y_{true},P_T) L=αH(ytrue,PS)+βH(PTτ,PSτ)+σH(ytrue,PT)
训练方案八:多教师蒸馏

在这里插入图片描述

  1. 设置不同的seed,使用标签数据训练教师模型,得到多个不同的教师模型
  2. 使用训练好的多个不同的教师模型,在未标注或者有标注的数据上预测,得到soft label。
  3. 对学生模型使用target label+soft label+relative dissmiarity三者结合进行学习。具体损失函数为 L ( θ S ) = ∑ [ H ( y i , N S ( x i ) ) ] + α H ( 1 m ∑ t = 1 m N T t τ ( x i ) , N S τ ( x i ) ) ] + β L R D ( w S ; x i , x i + , x i − ) L(\theta_S)=\sum[H(y_i,N_S(x_i))]+\alpha H(\frac{1}{m}\sum_{t=1}^mN^{\tau}_{T_t}(x_i),N_S^{\tau}(x_i))]+\beta L_{RD}(w_S;x_i,x^+_i,x^-_i) L(θS)=[H(yi,NS(xi))]+αH(m1t=1mNTtτ(xi),NSτ(xi))]+βLRD(wS;xi,xi+,xi)。第一部分为学生模型输出概率和数据标签的交叉熵,第二部分为学生模型和教师模型输出层的交叉熵,第三部分为学生模型和教师模型中间层的相对关系损失。

在训练时使用在线生成triplets的方法,但是Learning from Multiple Teacher Networks中在此处引用的FaceNet的triplet中的距离是基于单个模型的L2距离进行计算困难样本,但是本例中有多个教师模型需要计算样本,对于困难样本的选择没有详细介绍。个人拙见为对多个教师模型计算的L2距离取平均值来计算困难样本。

训练方案九:基于替换的模型蒸馏(Progressive Module Replacing)

在此方法中因为训练方法不同,将之前的教师模型和学生模型称为predecessor模型和successor模型。核心思路是将predecessor模型的模块一一生成successor模块,训练时随机使用successor模块替换对应的predecessor模型的模块。最后推断时使用全部的successor模块组成successor模型用于推断。

  1. 初始化predecessor模型 P = { p r d 1 , p r d 2 , . . . , p r d n } P=\{prd_1,prd_2,...,prd_n\} P={prd1,prd2,...,prdn}中每个predecessor每个模块 p r d i prd_i prdi的对应替代模块 s c c i scc_i scci。(对于替代模块可以考虑使用原有模型中的权重进行初始化,在BERT-of-Theseus: Compressing BERT by Progressive Module Replacing中使用base-bert中的前六层网络权重用于successor模块的初始化。)
  2. 对每个predecessor模块使用随机概率p去替换为successor模块,具体为 r i + 1 ∼ B e r n o u l l i ( p ) r_{i+1}\sim Bernoulli(p) ri+1Bernoulli(p), r i + 1 ∈ { 0 , 1 } r_{i+1}\in \{0,1\} ri+1{0,1}, y i + 1 = s c c i ( y i ) ∗ r i + 1 + p r d i ∗ ( 1 − r i + 1 ) y_{i+1}=scc_i(y_i)* r_{i+1}+prd_i*(1-r_{i+1}) yi+1=scci(yi)ri+1+prdi(1ri+1)(*为按位元素相乘)。训练过程中使用固定随机概率p=0.5或者线性增加的概率 p d = m i n ( 1 , θ ( t ) ) = m i n ( 1 , k t + b ) p_d=min(1,\theta(t))=min(1,kt+b) pd=min(1,θ(t))=min(1,kt+b)。损失函数为 L = − ∑ j ∈ ∣ X ∣ ∑ c ∈ C [ I [ z j = c ] ⋅ log ⁡ P ( z j = c ∣ x j ) ] L=-\sum\limits_{j\in |X|}\sum\limits_{c\in C}[\mathbb{I}[z_j=c]\cdot \log P(z_j=c|x_j)] L=jXcC[I[zj=c]logP(zj=cxj)](其中 x j ∈ X x_j\in X xjX是第j个训练样本采样, z j z_j zj为样本真实标签, c 和 C c和C cC分别表示样本预测标签和样本标签集合)。
  3. 使用所有的替代模块进行fine-tune,训练模型直到模型不再提升。提取所有替代模块得到successor模型推测。

**注意:**在训练的第2步中predecessor模块的权重全程冻结,successor和predecessor的embedding层和output层的权重冻结。

此方案在使用bert模型进行中文nlp任务替换时似乎效果并不明显,在6层时BERT-of-Theseus和finetune效果相差不多,在3层时才会产生较多差距。

具体见https://zhuanlan.zhihu.com/p/157899766和https://spaces.ac.cn/archives/7575结论部分内容。

工具介绍

EDL

EDL

通常的模型蒸馏需要教师模型提前完成完成数据预测或者在训练学生模型时实时预测,方法一导致硬盘占用、方法二因为需要学生模型等待教师模型预测,训练效率低。
EDL通过将教师模型部署为服务,最大限度的提高模型训练效率。

总结

知识蒸馏从最开始用复杂模型对未标注数据进行伪标注开始,始终贯彻课程式学习(Curriculum Learning)由易到难的思想。

在训练目标上先是将onehot标签改进为使用logits,后来又因为这样目标标签的logits值太高改用成更平滑的带有温度参数T的softmax标签,之后又因为使用欧氏距离约束过于严格而改用triplets loss这种软约束。

在训练过程中则通过空间和时间两个方面来降低模型的学习难度。在空间维度上使用网络的中间层来进行初步训练,再去使用模型最终输出进行训练;在时间维度选择学习网络的中间态,通过跟随教师模型的训练过程,让学生模型可以一点一点的跟随教师模型进行学习。

BERT-of-Theseus: Compressing BERT by Progressive Module Replacing的模块替换感觉本质上还是对中间层的替换,只不过其使用了更多的中间层,另外从训练技巧上使用替换而不是给每个学生中间层输出加上对应教师中间层输出约束。

感觉后续知识蒸馏的研究方向可以从使用的超参理论解释入手或者从训练目标(损失函数)考虑,在训练方法上感觉很难再找出比较好的创新点了。

附录

附录1
def _apply_lambda_teach(self, algorithm): 
        """Updates the teacher weight on algorithm based on the epochs elapsed."""
        if not self._initialized:
            self._init_lambda_teach = algorithm.cost.lambda_teach.get_value()
            self._step = ((self._init_lambda_teach - self.final_lambda_teach) /
                          (self.saturate - self.start + 1))
            self._initialized = True
        algorithm.cost.lambda_teach.set_value(np.cast[config.floatX](
            self.current_lambda_teach()))

def current_lambda_teach(self):
    """
    Returns the teacher weight currently desired by the decay schedule.
    """
    if self._count >= self.start:
        if self._count < self.saturate:
            new_lambda_teach = self._init_lambda_teach - self._step * (self._count
                    - self.start + 1)
        else:
            new_lambda_teach = self.final_lambda_teach
    else:
        new_lambda_teach = self._init_lambda_teach

    if new_lambda_teach < 0:
  	new_lambda_teach = 0

	return new_lambda_teach

参考

Model compression

Distilling the Knowledge in a Neural Network

FITNETS: HINTS FOR THIN DEEP NETS

Knowledge Distillation via Route Constrained Optimization

BERT 蒸馏在垃圾舆情识别中的探索

Learning from Multiple Teacher Networks

TextBrewer: An Open-Source Knowledge Distillation Toolkit for Natural Language Processing

BERT-of-Theseus: Compressing BERT by Progressive Module Replacing

BERT-of-Theseus:基于模块替换的模型压缩方法

关于"bert-of-theseus"一文的解读说明

TinyBERT:Distilling BERT for Natural Language Understanding

ERNIE-Tiny

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值