蒸馏论文一(knowledge distillation)

本文介绍知识蒸馏的经典论文(Distilling the knowledge in a neural network)。核心思想是通过迁移知识,从而通过训练好的大模型得到更加适合推理的小模型。

1. 核心思想

文章的核心思想就是提出用soft target来辅助hard target一起训练,而soft target来自于大模型的预测输出:

1、训练大模型:先用hard target,也就是正常的标签训练大模型。
2、计算soft target:利用训练好的大模型来计算soft target。也就是大模型“软化后”再经过softmax的输出。
3、训练小模型,在小模型的基础上再加一个额外的soft target的损失函数,通过alpha来调节两个损失函数的比重。
4、预测时,将训练好的小模型按常规方式(右图)使用。

在这里插入图片描述

2. 损失函数

class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s/self.T, dim=1)
        p_t = F.softmax(y_t/self.T, dim=1)
        loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
        return loss

知识蒸馏的关键是损失函数的设计,它包括普通的交叉熵损失和建立在soft target基础上的损失。

hard target 包含的信息量(信息熵)很低,soft target包含的信息量大,拥有不同类之间关系的信息。

比如,同时分类驴和马的时候,尽管某张图片是马,但是soft target就不会像hard target那样只有马的index处的值为1,其余为0,而是在驴的部分也会有概率。

这样的好处是,这个图像可能更像驴,而不会去像汽车或者狗之类的,而这样的soft信息存在于概率中,以及标签之间的高低相似性都存在于soft target中。

但是如果soft target是像这样的信息[0.98 0.01 0.01],就意义不大了,所以需要在softmax中增加温度参数T(这个设置在最终训练完之后的推理中是不需要的)。增加softmax后的蒸馏损失函数:

在这里插入图片描述
综合损失函数:
在这里插入图片描述
蒸馏损失的代码实现:

# ==============================蒸馏损失=============================== 
class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T

    def forward(self, y_s, y_t):
        # student网络输出软化后结果
        # log_softmax与softmax没有本质的区别,只不过log_softmax会得到一个正值的loss结果。
        p_s = F.log_softmax(y_s/self.T, dim=1)

        # # teacher网络输出软化后结果
        p_t = F.softmax(y_t/self.T, dim=1)

        # 蒸馏损失采用的是KL散度损失函数
        loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
        return loss

参考文献
深度学习方法(十五):知识蒸馏(Distilling the Knowledge in a Neural Network)

  • 4
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
知识蒸馏Knowledge Distillation)是一种将一个较大的模型的知识转移到一个较小的模型的技术。这种技术通常用于减小模型的大小和推理成本,同时保持模型在任务上的性能。 在Python中,你可以使用以下步骤来实现知识蒸馏: 1. 准备教师模型和学生模型:首先,你需要准备一个较大的教师模型和一个较小的学生模型。教师模型通常是一个预训练的大型模型,例如BERT或其他深度学习模型。学生模型是一个较小的模型,可以是一个浅层的神经网络或者是一个窄的版本的教师模型。 2. 训练教师模型:使用标注数据或其他训练数据集来训练教师模型。这个步骤可以使用常规的深度学习训练方法,例如反向传播和随机梯度下降。 3. 生成教师模型的软标签:使用教师模型对训练数据进行推理,并生成教师模型的软标签。软标签是对每个样本的预测概率分布,而不是传统的单一类别标签。 4. 训练学生模型:使用软标签作为学生模型的目标,使用训练数据集来训练学生模型。学生模型的结构和教师模型可以不同,但通常会尽量保持相似。 5. 进行知识蒸馏:在训练学生模型时,除了使用软标签作为目标,还可以使用教师模型的中间层表示或其他知识来辅助学生模型的训练。这可以通过添加额外的损失函数或使用特定的蒸馏算法来实现。 以上是实现知识蒸馏的一般步骤,具体实现细节可能因应用场景和模型而有所不同。你可以使用深度学习框架(如TensorFlow、PyTorch等)来实现这些步骤,并根据需要进行调整和扩展。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值