Keras-based Knowledge Distillation: Classification and Regression

This article introduces knowledge distillation: training a small student model to reproduce the behavior of a large teacher model. In TensorFlow and Keras this is done with a custom Distiller class, whose loss function uses a temperature parameter to soften the probability distributions and transfer the teacher's knowledge to the student. In the experiment, both the teacher and the student are convolutional neural networks trained and evaluated on MNIST, demonstrating that knowledge distillation improves the student's performance.

Knowledge Distillation

Introduction to Knowledge Distillation

Knowledge distillation is a model-compression procedure in which a small (student) model is trained to match a large, pre-trained (teacher) model. Knowledge is transferred from the teacher to the student by minimizing a loss function aimed at matching both the softened teacher logits and the ground-truth labels.

The logits are softened by applying a temperature scaling function in the softmax, which effectively smooths the probability distribution and reveals the inter-class relationships learned by the teacher.

Hinton et al. (2015)

Import the basic libraries

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
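
As a quick illustration of the temperature scaling described above, here is a minimal sketch (not part of the original tutorial; the logit values are made up) showing how dividing the logits by a larger temperature flattens the resulting distribution:

example_logits = tf.constant([[5.0, 2.0, 0.5]])  # illustrative logits for 3 classes
for T in (1.0, 3.0, 10.0):
    # Larger T spreads probability mass toward the smaller logits,
    # exposing the relative ordering the teacher has learned.
    soft = tf.nn.softmax(example_logits / T, axis=1)
    print("T =", T, "->", soft.numpy().round(3))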

Building the Distiller class

The custom Distiller() class overrides the Model methods train_step, test_step, and compile(). To use the distiller, we need:

  • a trained teacher model
  • a student model to train
  • a student loss function on the difference between student predictions and the ground truth
  • a distillation loss function, along with a temperature, on the difference between the soft student predictions and the soft teacher labels
  • an alpha factor to weight the student loss against the distillation loss
  • an optimizer for the student and (optional) metrics to evaluate performance

In the train_step method, we perform a forward pass of both the teacher and the student, compute the loss by weighting student_loss and distillation_loss by alpha and 1 - alpha respectively, and perform the backward pass. Note: only the student weights are updated, so we only compute gradients for the student weights.

In the test_step method, we evaluate the student model on the provided dataset.

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """ Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)

            # Compute scaled distillation loss from https://arxiv.org/abs/1503.02531
            # The magnitudes of the gradients produced by the soft targets scale
            # as 1/T^2, multiply them by T^2 when using both hard and soft targets.
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

Creating the student and teacher models

First, create a teacher model and a smaller student model. Both models are convolutional neural networks built with Sequential(), although any other Keras model would work.

# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)

# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)

# Clone student for later comparison
student_scratch = keras.models.clone_model(student)

Preparing the dataset

The dataset used for training the teacher and distilling it is MNIST, and the procedure is equivalent for any other dataset, e.g. CIFAR-10, given a suitable choice of models. Both the student and the teacher are trained on the training set and evaluated on the test set.

# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

Training the teacher

In knowledge distillation we assume that the teacher is trained and fixed. Thus, we start by training the teacher model on the training set in the usual way.

# Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5, batch_size=batch_size)
teacher.evaluate(x_test, y_test)

Distilling the teacher to the student

Having trained the teacher model, we only need to initialize a Distiller(student, teacher) instance, compile() it with the desired losses, hyperparameters, and optimizer, and distill the teacher to the student.

# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # for a regression task, swap in a regression loss (see the sketch below)
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3, batch_size=batch_size)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
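
The title also promises regression. The Distiller above hardcodes the softmax/temperature softening, which only makes sense for classification logits; for a regression task the distillation loss should instead be computed directly on the raw outputs. Below is a minimal, hedged sketch of such a variant, not part of the original tutorial: the subclass name RegressionDistiller and the models reg_teacher / reg_student are illustrative assumptions.

class RegressionDistiller(Distiller):
    """Hypothetical regression variant: distills raw outputs, no temperature."""

    def train_step(self, data):
        x, y = data

        # Teacher forward pass (the teacher stays frozen).
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)
            student_loss = self.student_loss_fn(y, student_predictions)
            # Match the teacher's raw predictions directly; the softmax and
            # temperature scaling of the classification version do not apply.
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, student_predictions
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Only the student weights are updated.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        self.compiled_metrics.update_state(y, student_predictions)
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

# reg_teacher and reg_student are assumed to be Keras regression models,
# e.g. ending in layers.Dense(1); define them for your own task.
# reg_distiller = RegressionDistiller(student=reg_student, teacher=reg_teacher)
# reg_distiller.compile(
#     optimizer=keras.optimizers.Adam(),
#     metrics=[keras.metrics.MeanAbsoluteError()],
#     student_loss_fn=keras.losses.MeanSquaredError(),
#     distillation_loss_fn=keras.losses.MeanSquaredError(),
#     alpha=0.1,
#     temperature=1,  # unused by this variant, kept for the compile() signature
# )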

Training the student from scratch for comparison

# Train student from scratch as usual
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3, batch_size=batch_size)
student_scratch.evaluate(x_test, y_test)
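
To put the three results side by side, here is a small convenience sketch (not in the original article). Note that with the compile() calls above, the plain models return [loss, accuracy] from evaluate, while distiller.evaluate returns [accuracy, student_loss] because of the custom test_step.

# Compare teacher, distilled student, and scratch student on the test set.
_, teacher_acc = teacher.evaluate(x_test, y_test, verbose=0)
distilled_acc, _ = distiller.evaluate(x_test, y_test, verbose=0)  # [acc, student_loss]
_, scratch_acc = student_scratch.evaluate(x_test, y_test, verbose=0)
print(f"teacher: {teacher_acc:.4f}, distilled: {distilled_acc:.4f}, scratch: {scratch_acc:.4f}")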

If the teacher is trained for 5 full epochs and the student is distilled on this teacher for 3 full epochs, then in this example you should experience a performance boost compared to training the same student model from scratch, and even compared to the teacher itself. You should expect the teacher to reach an accuracy of around 97.6%, the student trained from scratch around 97.6%, and the distilled student around 98.1%.