知识蒸馏的训练过程是通过结合硬标签损失(( L_{\text{hard}} ))和软标签损失(( L_{\text{soft}} ))进行反向传播,更新学生模型的参数。
具体流程如下:
-
前向传播:
- 教师模型 和 学生模型 分别对相同的输入数据进行前向传播,计算它们各自的输出。
- 教师模型的输出会生成“软标签”,即通过温度系数平滑过的类别概率分布。学生模型则输出它自己的类别概率分布。
-
计算损失:
- 硬标签损失(( L_{\text{hard}} )):这是学生模型的输出与真实标签之间的交叉熵损失,通常用于确保学生模型在最终的任务上取得好的性能。
- 软标签损失(( L_{\text{soft}} )):这是学生模型的输出概率分布与教师模型输出的“软标签”之间的差异,通常使用KL散度(Kullback-Leibler Divergence)来度量。通过软标签损失,学生模型能从教师模型的特征中学到更多细节信息。
损失函数通常是两者的加权和,公式如下:
[
L = \alpha \cdot L_{\text{hard}} + (1 - \alpha) \cdot L_{\text{soft}}
]- ( \alpha ) 是一个超参数,用来控制硬标签损失和软标签损失的相对权重。
- 温度系数 ( T ) 通常用于软化教师模型输出的概率分布,使其更加平滑,能提供更多类别之间的相关性信息。
-
反向传播:
- 计算出的总损失 ( L ) 会通过反向传播(Backpropagation)过程,更新学生模型的参数。
- 在反向传播过程中,损失函数的梯度会通过链式法则从输出层传回到模型的每一层,逐步调整模型参数,最终提升学生模型的表现。
-
迭代训练:
- 重复执行上述的前向传播、损失计算和反向传播,直到学生模型在训练集上达到期望的性能或者达到预设的训练轮数。
关键要点:
- 硬标签损失 ( L_{\text{hard}} ) 强调学生模型能够正确地学习真实标签,确保学生模型在任务上的准确性。
- 软标签损失 ( L_{\text{soft}} ) 则让学生模型学习教师模型的类别概率分布,使其能够捕捉更丰富的特征和类别之间的相关性。
- 反向传播 是通过计算出的总损失来更新学生模型的参数,最终优化学生模型的性能。
总结
知识蒸馏的反向传播过程是基于总损失函数,即硬标签损失和软标签损失的加权和。这个损失函数通过反向传播来优化学生模型的参数,使学生模型不仅能学习真实标签,还能从教师模型中吸收更多深层次的知识。