Relational knowledge distillation -基于关系建模的模型蒸馏

Paper name

Relational Knowledge Distillation

Paper Reading Note

URL: https://arxiv.org/abs/1904.05068

TL;DR

  • 提出了一种以模型输出的结构信息进行蒸馏的方式,对于metric learning、分类等任务有较大涨点效果

Introduction

  • 当前的SOTA模型基本都需要较大的计算量和存储消耗,一个有希望的解决方向是基于知识蒸馏方式将一个大模型的知识迁移到小模型中
  • 两个问题:
    • 模型中的知识是什么
      • 知识蒸馏将知识定义为模型学习到的输入到输出的映射关系
    • 怎么将模型的知识迁移到另一个模型中
      • 迁移方式是基于teacher模型的输出(hidden or last layer)来对student模型进行训练
  • 作者提出了一种以模型输出的结构信息进行蒸馏的方式,如下图所示
    在这里插入图片描述

Dataset/Algorithm/Model/Experiment Detail

在这里插入图片描述

实现方式
  • 传统的知识蒸馏方式如下,直接惩罚teacher模型和student模型的输出不一致性
    在这里插入图片描述

    • 例如hinton提出的基于KL散度来惩罚softmax(带温度项 τ \tau τ,温度项越大softmax输出越平滑,loss越大,避免训练前期陷入局部最优)的输出分布不一致性
      在这里插入图片描述
    • Romero提出的基于欧式距离来对网络的中间层输出直接计算loss,其中 β \beta β 为一个线性映射用于解决维度不一致问题
      在这里插入图片描述
  • Relational knowledge distillation

    • 计算模型输出的隐藏关系,以隐藏关系为知识迁移信号
      在这里插入图片描述
  • Distance-wise distillation loss

    • 首先计算模型的距离隐藏信息,对于模型的两个输出,以如下方式计算模型输出间的隐藏信息
      在这里插入图片描述
    • 其中 μ \mu μ 为距离规范化参数,为了关注模型输出对之间的相对距离,避免teacher模型和student模型之间的输出scale不匹配造成的影响
      在这里插入图片描述
    • 基于teacher和student的隐藏信息计算损失
      在这里插入图片描述
    • 其中损失计算方式为Huber loss
      在这里插入图片描述
  • Angle-wise distillation loss

    • 对于模型的三个输出,基于角度计算模型的输出的隐藏关系
      在这里插入图片描述
    • 基于模型输出的角度隐藏关系计算损失,计算损失的方式为Huber loss
      在这里插入图片描述
  • 基于RKD loss的模型训练方式

    • 使用RKD单独训练或者与task-specific loss联合训练都行,一般是遍历mini-batch中的所有样本来计算RKD loss
      在这里插入图片描述
    • 蒸馏层选择
      • 理论上可以选择模型的任何层
      • 当输出的绝对值重要的情况下不能单独使用RKD loss,需要与IKD loss或者task-specific loss联合使用
实验结果
  • Distillation to smaller networks
    样本间的关系的重要性对于metric learning任务是非常重要的,metric learning任务是训练一个embedding model来将数据样本映射到流形空间中,该空间中语义相似的样本的距离较近
    在这里插入图片描述
    可以看出rkd效果好,其中没有用l2 norm的情况下rkd的增益效果更大,原因是没有l2 norm能够利用更大的embedding space(l2 norm促使模型的输出lie on the surface of unit-hypersphere),部分student模型效果超过了teacher模型

  • Self-distillation
    在这里插入图片描述

  • Comparison with state-of-the art methods
    在这里插入图片描述

Thoughts

  • 与task-specific loss的权重比调节影响较大,RKD-DA的情况下DA的loss权重配比对于结果也影响较大
  • 对于students模型超过teachers模型现象,原因是teacher模型带有额外信息,比如跨类别关系,是无法被编码到one-hot形式的真值label中的,另外比如连续目标标签,如distance或者angle,都是带有有效信息的,这也无法被编码到二值的gt-lable中
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值