GNN Algorithms(8): Knowledge Distillation 知识蒸馏

知识蒸馏 knowledge distillation

核心:近似思想。student model -> 近似模拟teacher model。

目录

1. 知识蒸馏 Knoweldge Distillation

1.1 KD Concept

1.2 KD 基本算法思想

1.3 KD Loss: soft_loss & hard_loss

1.3.1 hard_loss

1.3.2 soft_loss 

1.3.3 KD_loss 

​1.4 KD 优点

1.5 KD application

2. Torch Implementation

2.1 KD 代码要点 

2.1.1 weight initialization

        2.1.2 log_softmax和softmax区别:

2.2 my implementation

Reference


1. 知识蒸馏 Knoweldge Distillation

1.1 KD Concept

知识蒸馏主要用于提取大型复杂模型中的知识,并传递给较小的模型这种技术有助于在不显著降低模型性能的前提下,减小模型的复杂度和计算资源需求

1.2 KD 基本算法思想

通过训练一个较小的student model,使其模仿一个较大的且更复杂的teacher model的输出。具体步骤如下:

1) Training a teacher model: 首先训练一个性能优异但可能复杂度较高的教师模型。

2) 获取软标签soft labels:使用teacher model对training data进行预测,获得软标签soft labels,这事概率分布,而不是单一的硬标签 hard labels。

3) Training a student model: 利用这些soft labels来训练较小的学生模型,使其学习teacher model输出的概率分布。

1.3 KD Loss: soft_loss & hard_loss

KD Loss: hard loss & soft loss. soft labels比hard labels包含更多的信息,因为它们反映了teacher model对不同类别的信息程度,这有助于student model更好地理解和泛化。

1) teacher loss: nn.CrossEntropyLoss()

2)student loss: KD Loss: soft_loss & hard_loss

1.3.1 hard_loss

hard loss: cross-entropy loss for (student_logits, labels)

1.3.2 soft_loss 

KLDivLoss for (log_softmax(student_logits/T), softmax(teaher_logits/Temperature)) = (soft_probs, soft_targets),即KL散度(Kullback-Leibler divergence),用于衡量两个概率分布之间差异的非对称性。

D_{KL}(P||Q) = \sum_i p_i log(\frac{p_i}{q_i})

pi是真实概率,qi是近似概率,log_softmax qi with temperature T: q_i = log\_so\!f\!t\!max(z_i/T)

1.3.3 KD_loss 

K\!D_{loss} = \alpha * T^2 * so\!f\!t\_los\!s + (1 - \alpha) * ha\!r\!d\_los\!s

1.4 KD 优点

  • 减少模型大小和复杂度: 可以在部署阶段使用较小的student model,减少存储和计算资源的需求。
  • 加快推理速度:较小的student model通常具有更快的推理速度,适合在实时或资源受限的环境中使用。
  • 保留性能:尽管模型简化了,但通过学习teacher model的知识,student model仍能保持较高的性能。

1.5 KD application

知识蒸馏的增量学习、模型压缩

incremental learning: 是指一个model能不断从新样本中学习新知识,并能保存大部分以前已经学习到的知识。

2. Torch Implementation

2.1 KD 代码要点 

2.1.1 weight initialization
  • torch.nn.init.normal_,正态分布random取值
  • torch.nn.xavier_normal_,glorot正态初始化,mean=0,std = \sqrt{2/(f\!an\_in + f\!an\_out)},fan_out是指输出神经元个数,glorot防止信号在ffd过程中逐渐放大或缩小,有助于减轻梯度消失或梯度爆炸。
2.1.2 log_softmax和softmax区别:
  • 分类问题一般用cross-entropy loss
  • 使用log_softmax: 一方面是为了解决溢出问题,另一方面是方便KL散度计算。

2.2 my implementation

my github link: https://github.com/yuyongsheng1990/Knowledge_Distillation_Models 

  • KD_Model_01: GPT_Generated_Codes.
  • KD_Model_02: PyTorch Tutorial.
  • KD_Model_03: textbrewer implemented knowledge distillation.

Reference

Knowledge Distillation Tutorial — PyTorch Tutorials 2.3.0+cu121 documentation

TextBrewer/examples/notebook_examples/msra_ner.ipynb at master · airaria/TextBrewer · GitHub

【NLP】(task7)Transformers完成序列标注任务-阿里云开发者社区

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

天狼啸月1990

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值