简单的知识蒸馏

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

import keras
from keras import layers
from keras import ops
import numpy as np

# 随着训练的进行,由于总损失(loss)是学生损失(student_loss)和蒸馏损失(distillation_loss)的加权和,
# 模型会同时考虑减少自身的预测损失(即提高预测准确性)和与教师模型预测分布的相似性。通过调整 alpha 参数,您
# 可以控制这两个目标之间的权衡。较大的 alpha 值将使模型更关注自身的预测准确性,而较小的 alpha 值则会使模型
# 更关注与教师模型预测分布的相似性。

# 知识蒸馏一般应该是从复杂的精度高的模型到简单的模型,让学生模型去学习教师模型的预测分布,但这个例子,因为简单的模型也
# 能达到不错的精度,所以没看出来性能提升

#教师模型一般应该是预训练模型,在高分辨率的图片数据集上训练过的,学生模型用来学习教师模型的预测概率分布
# 学生模型的结构和复杂性应该根据任务的要求、数据的特性以及资源限制来仔细选择。如果学生模型过于简单,它可能无法
# 捕捉到教师模型学习到的复杂模式和特征,导致性能不佳。另一方面,如果学生模型过于复杂,虽然它可能能够学习到更多
# 的细节和特征,但也可能导致过拟合,并且在计算资源上可能不高效。因此,在选择学生模型的结构时,需要进行权衡和实验。
# 一种常见的做法是从一个简单的模型开始,并逐步增加其复杂性,以观察性能如何变化。通过这种方式,可以找到在给定资源
# 和任务要求下性能最佳的学生模型结构。

#知识提炼器
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher#教师模型
        self.student = student#学生模型
    #编译,保存一些优化器,损失函数,权重,温度等的参数
    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn#学生损失函数
        self.distillation_loss_fn = distillation_loss_fn#蒸馏损失函数
        self.alpha = alpha#蒸馏权重
        self.temperature = temperature#温度参数
    #计算损失
    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)#获取教师模型的预测
        student_loss = self.student_loss_fn(y, y_pred)#根据学生损失函数计算学生损失
        # 计算蒸馏损失。这通常涉及将教师模型和学生模型的预测都通过softmax函数(使用温度参数进行缩放),
        # 然后计算两者之间的差异。这里乘以(self.temperature**2)是一个常见的调整,用于平衡蒸馏损失。
        distillation_loss = self.distillation_loss_fn(
            ops.softmax(teacher_pred / self.temperature, axis=1),
            ops.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)
        # 根据alpha参数,将学生损失和蒸馏损失组合成一个总损失
        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss
    def call(self, x):
        return self.student(x)

#教师模型比较大,学生模型比较小
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),#(14,14,256)
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),#(14,14,256)
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),#(7,7,512)
        layers.Flatten(),#(7*7*512)
        layers.Dense(10),
    ],
    name="teacher",
)

student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)
student_scratch = keras.models.clone_model(student)#新模型与原模型具有相同的结构,但不用原模型的权重,优化器等等

batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

print(x_train.shape,x_test.shape,x_train.dtype,np.max(x_train),np.min(x_train))

x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

print(x_train.shape,x_test.shape,x_train.dtype,np.max(x_train),np.min(x_train))

teacher.summary()#3*3*1*256+256,3*3*256*512+512,flatten:7*7*512,把像素值展平成向量,25088*10+10

teacher.compile(
    optimizer=keras.optimizers.Adam(),#优化器
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),#损失函数:多元交叉熵
    metrics=[keras.metrics.SparseCategoricalAccuracy()],#指标:准确率
)

teacher.fit(x_train, y_train, epochs=5)

teacher.evaluate(x_test, y_test)

distiller = Distiller(student=student, teacher=teacher)#构建知识提炼器

distiller.compile(
    optimizer=keras.optimizers.Adam(),#优化器
    metrics=[keras.metrics.SparseCategoricalAccuracy('acc')],#度量指标
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),#学生损失函数:多元交叉熵
    distillation_loss_fn=keras.losses.KLDivergence(),#知识提炼损失:kld
    alpha=0.1,#提炼权重(用来设定学生和提炼损失的占比)
    temperature=10,#温度,缩放系数
)

distiller.fit(x_train, y_train, epochs=3)#提炼教师到学生(让学生模型能学习教师模型的预测分布)

distiller.evaluate(x_test, y_test)

student_scratch.compile(#这个拷贝模型的职责就是衡量知识提炼中学生模型究竟从教师模型中学到了多少东西
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy('acc')],
)

student_scratch.fit(x_train, y_train, epochs=3)

student_scratch.evaluate(x_test, y_test)

  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值