NLP中知识蒸馏

NLP中的知识蒸馏

一、什么是知识蒸馏

知识蒸馏一个重要目的是让学生模型学习到老师模型的泛化能力,让轻量级的学生模型也可以具备重量级老师模型的几乎同样的能力。
一个很高效的蒸馏方法就是使用老师网络softmax层输出的类别概率来作为软标签,和学生网络的softmax输出做交叉熵。
传统训练方法是硬标签,正类是1,其他所有负类都是0。但知识蒸馏的训练过程过程是用老师模型的类别概率作为软标签。

二、为什么需要知识蒸馏

大模型虽然效果很好,但模型较重推理速度太慢无法瞒足工业要求,而小模型轻,推理速度快,但是直接使用数据训练效果较差,知识蒸馏就是想让小模型在拥有较快的推理速度下,也具备大模型的泛化能力。

三、知识蒸馏中的SoftMax

原始的softmax:

q i = e x p ( z i ) ∑ j e x p ( z j ) q_i = {\frac{exp(z_i)}{{\sum_{j}{exp(z_j)}}}} qi=jexp(zj)exp(zi)

上述有说到,知识蒸馏是student模型学习tearch模型的软标签,但是如果

  1. 教师模型softmax输出的软标签概率分布熵很小,就是负标签的概率值接近于0,学生模型负标签知识学习很弱,那么该值对损失函数的影响会很小
  2. 如果两个logits差异较大,而使用softmax之后会导致两个输出的差异降低,从而减少了模型的知识,而使用softmax-T后,会更好的保留两者之间的差异。
    所以对softmax加了温度:

q i = e x p ( z i / T ) ∑ j e x p ( z j / T ) q_i = {\frac{exp(z_i/T)}{{\sum_{j}{exp(z_j/T)}}}} qi=jexp(zj/T)exp(zi/T)

根据公式可以看出,让T越大时,softmax输出值越平滑,输出值得熵越大,会放大负标签携带的信息,模型会相对校对的关注负标签,能够充分的学习。 一般来讲T会大于1;

四、如何选择温度:

说白了温度的高低改变的是学生网络对负标签的关注程度

  • 温度较低时,负类别携带的信息会被相对减少,对负类别的关注较少,负类别的概率越低,关注越少。
  • 温度较高时,负类别的概率值会相对增大,负类别携带的信息会被相对地放大,学生网络会更多关注到负标签。

实际上,负类别中包含一定的信息,尤其是那些概率值较高的负类别。 但由于老师网络的负类别可能会有噪声,并且负类别的概率值越低,其信息就越不可靠。因此温度的选取比较看经验,本质上就是在下面两件事之中取舍

  1. 从负类别中获取信息 --> 温度要高一些
  2. 防止受负类别中噪声的影响 --> 温度要低一些

总的来说,温度的选择和学生网络的大小有关,学生网络参数量比较小的时候,相对比较低的温度就可以了,因为参数量小的模型不能捕获所有知识,所以可以适当忽略掉一些负标签的信息。

五、如何蒸馏、LOSS是什么样的

  1. 如何蒸馏:
    第一步是训练老师网络;第二步是蒸馏老师网络的知识到学生网络。
  2. 损失函数:
    高温蒸馏过程的目标函数由distill loss(对应软标签)和student loss(对应硬标签)加权得到。
IMG20220919161545

L = α L s o f t + β L h a r d L = \alpha L_{soft} + \beta L_{hard} L=αLsoft+βLhard

distill loss(对应软标签) :
是老师模型softmax经过高温后输出的概率分布和学生网络在同等温度下的概率分布做交叉熵, 软标签 loss:

L s o f t = − ∑ j = 1 n p j T l o g ( q j T ) ,其中 p i T = e x p ( v i / T ) ∑ k = 1 n e x p ( v k / T ) , q i T = e x p ( v i / T ) ∑ k = 1 n e x p ( v k / T ) L_{soft} = -{\sum_{j=1}^n {p_j^T log(q_j^T)}},其中 p_i^T = {\frac{exp(v_i/T)}{{\sum_{k=1}^n exp(v_k/T)}}}, q_i^T = {\frac{exp(v_i/T)}{\sum_{k=1}^n {exp(v_k/T)}}} Lsoft=j=1npjTlog(qjT),其中piT=k=1nexp(vk/T)exp(vi/T)qiT=k=1nexp(vk/T)exp(vi/T)

student loss(对应硬标签) :
是学生网络在温度为1下的概率分布和真实标签做交叉熵,硬标签 loss:

L h a r d = − ∑ j n c j 1 l o g ( q j 1 ) ,其中 q j 1 = e x p ( v j ) ∑ k n e x p ( v k ) L_{hard} = -{\sum_j^n {c_j^1 log(q_j^1)}}, 其中 q_j^1 = {\frac{exp(v_j)}{\sum_{k}^n {exp(v_k)}}} Lhard=jncj1log(qj1),其中qj1=knexp(vk)exp(vj)

六、项目开展和算法调优过程

损失函数的比较和选择
  • 交叉熵损失(CrossEntropyLoss):基于softmax-T计算损失。其中softmax-T上述有过介绍,不在过多赘述。
  • 均方差损失(MESLoss):基于logits直接计算。
    在我的实验中,两者之间的训练结果并无太大差异,反而MSELoss计算方法获得的结果更优。 (其实是近MSELoss,但大多数实验者直接用MSELoss替代)基本类似。
    使用MSELoss的另一个好处是,避免了超参数T的使用。 超参数T的使用还会影响soft-loss和hard-loss的比重,虽然理论上需要给soft-loss乘以 T 2 T^2 T2 ,让彼此的权重在同一个数量级上。
    对于知识蒸馏建议使用MSELoss,而非使用原本的softmax-T-loss(Hinton,2014),能达到更好的效果,理论和实验都有证明。
  • 项目中的具体做法:
  • 这里我使用了一种soft-label的方法。是将teacher模型的logits表示经过softmax后,与one-hot表征的实际label进行相加, 注意这里引入相加的权重alpha,实验做好的值为0.5,alpha越大越依赖教师模型的logits。然后得到一个新的label表示。如teacher-logit-softmax = [0.2,0.7,0.1],实际标签one-hot = [0,1,0],alpha = 0.5,那么最后的label = [0.1,0.85,0.05],然后用这个label和student计算获得的logits进行MSELoss计算,求导。这种方法获得的结果和直接用MSELoss计算后,然后使用alpha权重相加结果类似,但好处是少了一次MSELoss的计算过程,在训练时,训练速度更快。
Teacher模型

因为是中文NLP任务,对于teacher模型选择的标准是,尽量好,尽量优秀,甚至可以使用集成学习的方法获得最优结果。
项目中使用了中文的Roberta-base模型作为teacher模型(已经对下游NER实体进行了Finetune,精度F1 = 94.67%),具体参数:Epoch=3,max_sequence_length = 256,batch_size = 32,model_size = 42.2M。

Student模型

前后使用多种Student模型,选取的条件是速度满足当前模型预测的速度要求,(不做蒸馏前,纯训练)精度越高越好。
将BERT模型蒸馏至TextCNN 和BiLSTM等小模型上,精度下降3%,速度提升400倍。注意文章使用了word2vec词向量,并非完全从头训练,具体细节可看论文和代码。
蒸馏学习的Student模型分别使用了ALBert-Base,ALBert-Tiny和ELECTRA-Small这三种模型,模型使用的alpha = 0.5,使用的是MSELoss的方法,具体结果:

ModelBase-f1Distill-f1M-SizeI-Time
Roberta-Base0.94675412M95.35s
AlBert-base0.8100.90942.2M27.3
ALBert-Tiny0.6120.82416.3M9.3s
ELECTRA-small0.91240.926749.4M24.8s
蒸馏中的数据增强

使用训练好的teacher模型对数据打标,形成伪标签,再训练student模型,即使部分case teacher没有标对,也没有很大的关系,目的就是让student更像teacher,本来badcase就很小,对训练影响度有限,但是伪标签数据不易过度,以免真正影响效果。

Batch-size和max-sequence-length的使用:
  1. 训练数据的平均长度=37 chars,一般使用2倍平均长度即可获取较好结果,这里使用的是128长度,能够涵盖大部分训练数据而不会导致数据流失。
  2. 对于标签不均衡的训练数据来说,扩大batch-size能比较好的覆盖更全的label,让模型能尽快找到更合适的训练方向。
多步蒸馏到超小模型

如上述实验中的奖RoBerta-Base模型内容蒸馏到ALBert-Tiny,模型的size差异大约在30倍,如果直接蒸馏,效果会不好。精度大约只能达到82.4%。这里可以借鉴miniLM(Ref-9)的一种操作Trick,间接蒸馏。具体做法是先将大模型(如:RoBerta-Base,94.7%)里的知识蒸馏到一个中(过渡)模型(如:ELECTRA-base,92.1%),然后再用中模型作为teacher,将知识蒸馏到真正的小模型(如此处的ALBert-Tiny),模型精度最终可以达到88.3%,精度大约有6个点的提升。

七、知识蒸馏需要注意的点

  • 温度T,高温T。通常模型训练的时候使用高温T,而在模型测试和预测阶段的时候,是不使用teacher模型的,仅使用student模型进行测试和预测,也就是T在预测阶段不使用。

  • MSELoss计算L-soft。不使用上述复杂的L-soft,而使用简单的均方差损失函数——MSELoss。

  • Hard-Loss加入模型。即便使用了Soflt-Loss,还是需要引入Hard-Loss以及超参数 目的是teacher模型也可能存在无法完全学对的可能,所以在数据质量有保证的情况下,引入学生模型的hard-loss能更好的学会teacher无法学会的知识。实际使用过程中也发现,引入hard-loss很有效果。

  • 使用更多损失函数。,Hinton的蒸馏学习使用的是Cross-Entorpy作为损失函数,其实损失函数不止于交叉熵损失函数,包括MSELoss,NLLLoss,HingeLoss等。实验中,我使用了MSELoss和CELoss做比较,发现二者对于Student模型的效果类似,所以对于不同下游任务,可以使用更贴合的Loss函数,不必局限于CELoss。但是对于蒸馏学习的理解一定要到位,才能更合理的利用Loss。

  • 集成学习加入到蒸馏学习中。通常我们不会仅仅使用一个老师,而是使用多个teacher,然后将多个teacher的知识权重相加引入到student模型中。这是将集成学习和蒸馏学习相融合,能让学生学到更多信息,但是也同时增加了模型训练的难度(增多了超参数的数量以及集成学习方法的比较),对于初学者不建议使用。

八、知识蒸馏的几个思考

student loss的必要性

因为老师网络也有一定的错误率,使用ground truth可以有效降低错误被传播给学生网络的可能。
举例,老师虽然学识远远超过学生,但是他仍然有出错的可能,而这时候如果学生在老师的教授之外,可以同时参考到标准答案,就可以有效地降低被老师偶尔的错误“带偏”的可能性。

为什么student loss(硬标签)所占比重比较小的时候,能产生最好的结果

这是一个经验的结论。一个可能的原因是,由于soft target产生的gradient与hard target产生的gradient之间有与 T 相关的比值。
在同时使用soft target和hard target的时候,需要在soft target之前乘上 T 2 T^2 T2 这个系数,这样才能保证soft target和hard target贡献的梯度量基本一致。

能不能直接match logits(不经过softmax)

直接match logits指的是,直接使用softmax层的输入logits(而不是输出)作为soft targets,需要最小化的目标函数是Net-T和Net-S的logits之间的平方差。直接上结论: 直接match logits的做法是 T→∞ 的情况下的特殊情形。

  • 12
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
一、设计目的 本课程设计旨在通过理论和实践相结合的方式,使学生掌握知识蒸馏的基本概念、原理和方法,培养学生的独立思考和解决实际问题的能力,提高学生的综合素质和实践能力。 二、设计内容 1. 知识蒸馏的基本概念和原理 - 知识蒸馏的概念和定义 - 知识蒸馏的原理和分类 - 知识蒸馏的应用领域和前景 2. 知识蒸馏的方法和技术 - 模型压缩方法 - 数据增强方法 - 策略裁剪方法 - 模型蒸馏方法 3. 知识蒸馏的实践应用 - 图像分类任务的知识蒸馏 - 语音识别任务的知识蒸馏 - 自然语言处理任务的知识蒸馏 4. 知识蒸馏的评价和优化 - 模型蒸馏的评价方法 - 知识蒸馏的效果和性能评价 - 知识蒸馏的优化方法 三、教学方法 1. 理论授课:通过讲授知识蒸馏的概念、原理、方法和应用等方面的内容,使学生掌握知识蒸馏的基本概念和实现方法。 2. 实践操作:通过实践操作,让学生亲自动手实现知识蒸馏,并通过实验结果评价知识蒸馏的效果和性能。 3. 讨论研究:通过讨论研究,让学生了解知识蒸馏的最新进展和应用领域,培养学生的独立思考和解决实际问题的能力。 四、实验项目 1. 图像分类任务的知识蒸馏实验 - 实验目的:通过实现图像分类任务的知识蒸馏,掌握知识蒸馏的基本方法和技术。 - 实验内容:使用CIFAR-10数据集进行实验,比较蒸馏前后的模型性能和效果差异。 - 实验要求:熟悉Python编程和深度学习框架,具备良好的实验设计和结果分析能力。 2. 语音识别任务的知识蒸馏实验 - 实验目的:通过实现语音识别任务的知识蒸馏,掌握知识蒸馏在语音处理领域的应用。 - 实验内容:使用TIMIT数据集进行实验,比较蒸馏前后的模型性能和效果差异。 - 实验要求:熟悉Python编程和深度学习框架,具备良好的实验设计和结果分析能力。 3. 自然语言处理任务的知识蒸馏实验 - 实验目的:通过实现自然语言处理任务的知识蒸馏,掌握知识蒸馏在自然语言处理领域的应用。 - 实验内容:使用IMDB数据集进行实验,比较蒸馏前后的模型性能和效果差异。 - 实验要求:熟悉Python编程和深度学习框架,具备良好的实验设计和结果分析能力。 五、评分标准 1. 实验报告(60%):包含实验设计、实验步骤、实验结果和分析等内容。 2. 实验操作(20%):包括实验操作流程、代码编写和实验调试等内容。 3. 学术论文阅读和报告(20%):包括阅读学术论文、撰写学术报告和答辩等内容。 六、参考教材 1.《深度学习基础教程》 2.《深度学习与自然语言处理》 3.《深度学习优化方法》 4.《深度学习模型压缩与蒸馏

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值