Knowledge Distillation
Introduction to Knowledge Distillation
知识提取是一种模型压缩过程,其中对小(学生)模型进行训练,以匹配预先训练的大(教师)模型。通过最小化损失函数,将知识从教师模型转移到学生身上,目的是匹配软化的教师逻辑和基本事实标签。
通过在softmax中应用“温度”标度函数来软化logits,有效地平滑了概率分布,并揭示了教师学习到的课堂间关系。
导入基础库
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
构造Distiller类
自定义Distiller()类覆盖Model方法train_step、test_step和compile()。为使用蒸馏器,我们需要:
- 训练有素的教师模型
- 要训练的学生模型
- 关于学生预测和基本事实之间差异的学生损失函数
- 关于学生软预测和教师软标签之间差异的蒸馏损失函数以及温度
- 衡量学生体重和蒸馏损失的阿尔法因素
- 针对学生的优化器和(可选)评估绩效的指标
在train_step方法中,我们执行教师和学生的前向传递,分别通过α和1-alpha对student_loss和distraction_loss进行加权来计算损失,并执行后向传递。注意:只有学生权重会更新,因此我们只计算学生权重的梯度。
在test_step方法中,我们在提供的数据集上评估学生模型。
class Distiller(keras.Model):
def __init__(self, student, teacher):
super().__init__()
self.teacher = teacher
self.student = student
def compile(
self,
optimizer,
metrics,
student_loss_fn,
distillation_loss_fn,
alpha=0.1,
temperature=3,
):
""" Configure the distiller.
Args:
optimizer: Keras optimizer for the student weights
metrics: Keras metrics for evaluation
student_loss_fn: Loss function of difference between