Learning Efficient Object Detection Models with Knowledge Distillation
之前博客整理的论文都是knowledge distillation及其变体,作为机器学习的一种方法的研究发展历程。从这篇博客开始,我将介绍其在CV领域的一些具体的用法。
本文是knowledge distillation在detection上成功应用的一个例子。
概述
knowledge distillation和hint learning在classification已经很成功了。然而对于detection,soft target不再是单一的类别概率输出,regression、proposal、less voluminous labels(较少的标签)都是在detection种使用distillation的挑战:
本文应该是最先在detection种成功使用distillation的,主要idea有:
- end-to-end 使用distillation方式训练
- loss定义,a) weighted cross entropy loss for classification b) teacher bounded regression loss for knowledge distillation c) adaptation layers for hint learning
Method
整体框架:
总的loss有三大部分。
L
H
i
n
t
L_{Hint}
LHint是使用hint-based方法学习teacher经过backbone后feature的表达;
L
R
P
N
L_{RPN}
LRPN是学习teacher RPN部分的proposals,包括classification®ression两部分loss;
L
R
C
N
L_{RCN}
LRCN是学习teacher fast rcnn detector部分的prediction,也包括classification score®ression factor两部分
loss按种类来分,有三种计算方式,分别用于classification、regression、feature adaption
下面依此介绍:
Knowledge Distillation for Classification with Imbalanced Classes
detection与classification的差异,detection在fg和bg的区分上容易误判,所以引入一个权重加大对bg类的惩罚
w
0
=
1.5
w_0=1.5
w0=1.5,其他
w
i
w_i
wi都是1:
另外,对于简单的数据集且是分类任务,需要设置temperature T,是输出分布更soft,拉近类别间的差距。
而对于detection这样一个本身就比较难的任务来说,很多类都有明显的预测误差,设置T=1时性能是最好的
Knowledge Distillation for Regression with Teacher Bounds
Knowledge Distillation for Regression有一个麻烦就是 regression direction可能和GT相差甚远:
策略是,当student的
R
s
R_s
Rs偏的比teacher还要离谱一些(margin m)时,加入一个loss惩罚
和teacher的表现相近时,就不push了,此项loss为0
可以看到即使偏的离谱的时候,student学习的还是hard label。可能这个没有软分布,teacher输出可能误差也比较大,不如直接学习GT,只是和teacher表现差距太大时,多加一项loss
Hint Learning with Feature Adaptation
作者似乎L1 L2 loss 都进行了尝试:
当hint layer 和 guided layer不匹配时,需要用一个adaption layer进行转换。hint and guided都是FC时就用FC,都是conv layer就用1x1conv来匹配。而且作者发现即便channels数一样时,一个额外的adaptation layer也有助于实现高效的知识迁移
实验结果
可以看出来提升还是挺明显的,用VGGM做detection,效果对比未使用蒸馏提升了四个点。这个网络速度只占VGG16的1/3,虽然效果还是查了蛮多。