知识蒸馏 knowledge distillation
核心:近似思想。student model -> 近似模拟teacher model。
目录
1. 知识蒸馏 Knoweldge Distillation
1.3 KD Loss: soft_loss & hard_loss
1. 知识蒸馏 Knoweldge Distillation
1.1 KD Concept
知识蒸馏主要用于提取大型复杂模型中的知识,并传递给较小的模型。这种技术有助于在不显著降低模型性能的前提下,减小模型的复杂度和计算资源需求。
1.2 KD 基本算法思想
通过训练一个较小的student model,使其模仿一个较大的且更复杂的teacher model的输出。具体步骤如下:
1) Training a teacher model: 首先训练一个性能优异但可能复杂度较高的教师模型。
2) 获取软标签soft labels:使用teacher model对training data进行预测,获得软标签soft labels,这事概率分布,而不是单一的硬标签 hard labels。
3) Training a student model: 利用这些soft labels来训练较小的学生模型,使其学习teacher model输出的概率分布。
1.3 KD Loss: soft_loss & hard_loss
KD Loss: hard loss & soft loss. soft labels比hard labels包含更多的信息,因为它们反映了teacher model对不同类别的信息程度,这有助于student model更好地理解和泛化。
1) teacher loss: nn.CrossEntropyLoss()
2)student loss: KD Loss: soft_loss & hard_loss
1.3.1 hard_loss
hard loss: cross-entropy loss for (student_logits, labels).
1.3.2 soft_loss
KLDivLoss for (log_softmax(student_logits/T), softmax(teaher_logits/Temperature)) = (soft_probs, soft_targets),即KL散度(Kullback-Leibler divergence),用于衡量两个概率分布之间差异的非对称性。
pi是真实概率,qi是近似概率,log_softmax qi with temperature T: 。
1.3.3 KD_loss
1.4 KD 优点
- 减少模型大小和复杂度: 可以在部署阶段使用较小的student model,减少存储和计算资源的需求。
- 加快推理速度:较小的student model通常具有更快的推理速度,适合在实时或资源受限的环境中使用。
- 保留性能:尽管模型简化了,但通过学习teacher model的知识,student model仍能保持较高的性能。
1.5 KD application
知识蒸馏的增量学习、模型压缩
incremental learning: 是指一个model能不断从新样本中学习新知识,并能保存大部分以前已经学习到的知识。
2. Torch Implementation
2.1 KD 代码要点
2.1.1 weight initialization
- torch.nn.init.normal_,正态分布random取值
- torch.nn.xavier_normal_,glorot正态初始化,mean=0,
,fan_out是指输出神经元个数,glorot防止信号在ffd过程中逐渐放大或缩小,有助于减轻梯度消失或梯度爆炸。
2.1.2 log_softmax和softmax区别:
- 分类问题一般用cross-entropy loss
- 使用log_softmax: 一方面是为了解决溢出问题,另一方面是方便KL散度计算。
2.2 my implementation
my github link: https://github.com/yuyongsheng1990/Knowledge_Distillation_Models
- KD_Model_01: GPT_Generated_Codes.
- KD_Model_02: PyTorch Tutorial.
- KD_Model_03: textbrewer implemented knowledge distillation.
Reference
Knowledge Distillation Tutorial — PyTorch Tutorials 2.3.0+cu121 documentation
TextBrewer/examples/notebook_examples/msra_ner.ipynb at master · airaria/TextBrewer · GitHub