基于keras的知识蒸馏(Knowledge Distillation)-分类与回归

本文介绍了知识蒸馏的概念,即通过训练小型学生模型来复制大型教师模型的行为。在TensorFlow和Keras中,通过定制Distiller类,利用损失函数和温度参数来软化概率分布,从而将教师模型的知识转移给学生模型。实验中,教师和学生模型均为卷积神经网络,在MNIST数据集上进行训练和评估,展示了知识蒸馏能提高学生模型的性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Knowledge Distillation

Introduction to Knowledge Distillation

知识提取是一种模型压缩过程,其中对小(学生)模型进行训练,以匹配预先训练的大(教师)模型。通过最小化损失函数,将知识从教师模型转移到学生身上,目的是匹配软化的教师逻辑和基本事实标签。

通过在softmax中应用“温度”标度函数来软化logits,有效地平滑了概率分布,并揭示了教师学习到的课堂间关系。

Hinton et al. (2015)

导入基础库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

构造Distiller类

自定义Distiller()类覆盖Model方法train_step、test_step和compile()。为使用蒸馏器,我们需要:

  • 训练有素的教师模型
  • 要训练的学生模型
  • 关于学生预测和基本事实之间差异的学生损失函数
  • 关于学生软预测和教师软标签之间差异的蒸馏损失函数以及温度
  • 衡量学生体重和蒸馏损失的阿尔法因素
  • 针对学生的优化器和(可选)评估绩效的指标

在train_step方法中,我们执行教师和学生的前向传递,分别通过α和1-alpha对student_loss和distraction_loss进行加权来计算损失,并执行后向传递。注意:只有学生权重会更新,因此我们只计算学生权重的梯度。

在test_step方法中,我们在提供的数据集上评估学生模型。

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """ Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值