Knowledge Distillation论文阅读(2):Learning Efficient Object Detection Models with Knowledge Distillation

由于论文中有很多冗余的话,在这篇文章的阅读中,我只总结比较重要的部分,而不再对论文进行逐字逐句的翻译

Abstract

相比于图像分类,把知识蒸馏应用到多类别的目标检测中是一个很大的挑战:

  • 压缩的模型会造成更严重的退化问题而影响模型检测目标的表现;因为标注标签比较昂贵,所以数据集通常容量相对来说没有那么大
  • 知识蒸馏适用于假设各种被分类的类别都具有同等的重要性;但是在目标检测任务中,情况有很大的区别,因为通常被检测图片的背景都会占很大的比例
  • 目标检测是一个更加复杂的任务,结合了分类和边界框回归的任务
  • 最后,一个额外的挑战是:我们致力于在同一个域内转移知识,不适用任何附加的数据或者label;而其他工作可能依赖于其他领域的数据(例如:高质量或低质量的图像域或者图像和深度域)

为了解决以上的挑战,本文提出了一个方法通过知识蒸馏来训练快速的目标检测模型,贡献如下:

  • 提出了一个端到端(end-to-end)的网络结构通过知识蒸馏来学习紧凑网络的“多类别检测任务”;这是第一次知识蒸馏被用来解决多类目标检测问题
  • 提出了新的损失函数,非常高效地解决了前面的问题;提出了一个使用权重的 weighted cross_entropy 损失来做分类,成功并有效地解决了分类任务在背景和目标之间不平衡的表现;我们还设计了一个 teacher bounded regression loss 用来完成知识蒸馏工作,还设计了一个适应层(adaption layers for hint learning)让 student 网络可以更好地从 teacher 网络的中间层学到有用的信息。
  • 此外,使用多个大型公共 benchmarks 进行全面的评估。证明了在所有公共的 benchmarks 上,表现都是最佳
  • 我们通过将框架泛化和拟合不足问题联系起来,展示对其所有表现的见解。

Method

本文的工作中借鉴了 Faster-RCNN 目标检测的框架;Faster-RCNN 的组成部分如下:

  • 通过卷积层的共享特征提取
  • 一个区域建议网络(RPN)产生目标建议区域
  • 一个分类和回归网络(RCN),对于每一个目标建议区域(object proposal)返回目标检测的分数 (detection score) 和空间矫正向量(spatial adjustment vector)

RCN 和 RPN 都使用卷积层的输出为特征,RCN 同时也使用 RPN 的输出结果作为输入;为了实现高精度的目标检测,对着三个部分的训练是非常重要的。

3.1 Overall Structure

整体网络结构如下:
在这里插入图片描述

  • 我们采用基于提示的学习策略(hint-based learning)来促使 student 网络学习的特征更加接近 teacher 网络

  • 我们把知识蒸馏的方法与 RCN 和 RPN 网络相结合;为了解决目标检测中严重的不平衡问题(背景和检测目标物体的不平衡),我们对知识蒸馏网络使用了一个具有权重的交叉熵损失(weighted cross entropy)

  • 最后,我们把 teacher 网络回归的输出作为 “ 上界 ”,也就是说,如果 student 网络的回归输出比 teacher 网络的表现要好,那么将不会触发损失函数对其进行惩罚

整体的损失函数设计如下:

L R C N = 1 N ∑ i L c l s R C N + λ 1 N ∑ j L r e g R C N L_{RCN}=\frac{1}{N}\sum_iL_{cls}^{RCN} + \lambda\frac{1}{N}\sum_jL_{reg}^{RCN} LRCN=N1iLclsRCN+λN1jLregRCN

L R P N = 1 M ∑ i L c l s R P N + λ 1 M ∑ j L r e g R P N L_{RPN}=\frac{1}{M}\sum_iL_{cls}^{RPN} + \lambda\frac{1}{M}\sum_jL_{reg}^{RPN} LRPN=M1iLclsRPN+λM1jLregRPN

L = L R P N + L R C N + γ L H i n t L=L_{RPN}+L_{RCN}+\gamma L_{Hint} L=LRPN+LRCN+γLHint ·········· ( 1 ) (1) (1)

  • N N N 是输入 RCN 网络的 batch size, M M M 是输入 RPN 网络的 batch size,
  • L c l s L_{cls} Lcls 代表分类器的损失函数,它结合了 hard softmax 的损失(使用 ground_truth label 得到结果的损失)还有知识蒸馏产生的 soft 损失;此外,
  • L r e g L_{reg} Lreg 是边界框bounding box 回归损失 + smoothed L1 正则化损失同时也结合了我们之前提过的 teacher bounded L2 回归损失( teacher 网络的上界损失)来激励 student 网络能够
  • 最后, L H i n t L_{Hint} LHint 代表的是 hint 损失函数,这个帮助 student 网络更好地模仿 teacher 特征的。
  • λ \lambda λ γ \gamma γ 是超参数,用来控制不同损失之间的平衡,暂且设置他们分别为 1 和 0.5。

3.2 Knowledge Distillation for Classification with Imbalanced Classes

  • 假设我们有数据集 { x i , y i } , i = 1 , 2 , 3 , . . . . , n \{x_i,y_i\},i=1,2,3,....,n {xi,yi},i=1,2,3,....,n x i ∈ J x_i∈\Bbb J xiJ 是输入图片;
  • y i ∈ Y y_i∈\Bbb Y yiY 是类别的标签;
  • t t t 表示 teacher 网络, P t = s o f t m a x ( Z t T ) P_t = softmax(\frac{Z_t}{T}) Pt=softmax(TZt) 是 teacher 网络预测的离散的概率分布结果, Z t Z_t Zt 是最终的分数输出;这里 T T T 是一个温度参数(通常设置为1);
  • 同样的,定义 P s = s o f t m a x ( Z s T ) P_s = softmax(\frac{Z_s}{T}) Ps=softmax(TZs) 代表学生网络 s 的预测的离散的概率分布结果。student 网络根据下面损失函数的公式进行优化:

L c l s = μ L h a r d ( P s , y ) + ( 1 − μ ) L s o f t ( P s , P t ) L_{cls}=μL_{hard}(P_s,y)+(1-μ)L_{soft}(P_s,P_t) Lcls=μLhard(Ps,y)+(1μ)Lsoft(Ps,Pt) ·········· ( 2 ) (2) (2)

  • 这里的 L h a r d L_{hard} Lhard 是指 hard loss,是 Faster-RCNN 训练 teacher 网络的时候使用的损失函数(这里可以理解离散的标签是 hard label刚性的标签,给 teacher 网络用的,同样的,使用刚性标签训练时使用的损失函数就是 hard loss;而连续的分布标签是 soft label 是软标签,是给 student 网络用的,使用 soft label 训练 student 网络时使用的就是 soft loss);
  • L s o f t L_{soft} Lsoft是 soft loss,是把 teacher 网络的输出结果当做 student 训练数据的 label 的时候使用的损失函数;
  • μ μ μ 是平衡 hard loss 和 soft loss 的参数。

soft label(从 teacher 网络输出的离散的概率分布)包含了不同类别之间的信息,这些信息是 teacher network 在训练的时候发现的,表现在了 teacher 网络输出的离散概率分布中。通过学习 soft label,student 网络可以继承这些隐含信息。

在目标检测网络的训练中,一个很大的问题就是背景background 和 需要被检测的目标 object 之间存在较大的差异,因为背景一般情况下在图片中占主要的成分;在图像分类中,唯一可能出现错误的地方,就是图片的“前景”不同物体的分类,但是在目标检测中,无法区分背景和前景也将会导致严重的错误,而类别之间的分类错误却变得罕见了。为了解决这个问题,本文提出了一个 class-weighted cross entropy 的损失函数(类别权重交叉熵损失)用来平衡背景和前景的不同类别目标的分类过程:

L s o f t ( P s , P t ) = − ∑ w c P t l o g P s L_{soft}(P_s,P_t)=-\sum w_cP_tlogP_s Lsoft(Ps,Pt)=wcPtlogPs ·········· ( 3 ) (3) (3)

  • 针对背景,我们给他分配一个相对较大的权重,而对其他需要检测的目标种类,我们分配较小的权重。

  • w 0 = 1.5 w_0=1.5 w0=1.5 给背景权重,而 w i = 1 w_i=1 wi=1 给其他的类别目标,实验进行再 PASCAL数据集上

  • P t P_t Pt (即通过 teacher 网络输出的离散概率分布(软标签))仍然很像 hard label(也就是说,通过 teacher 网络之后,这些标签仍然很刚性);这个时候,我们通过温度参数 T T T 来软化这个结果;使用更大的温度参数 T T T 将迫使 t t t 产出更软的标签,使概率接近于零的类别不会被成本函数忽略;

  • 这在简单的分类任务中是有效的。但是对于一些更难的问题,结果就不尽如人意了,错误率会比较高,原因是,更大的 T T T 会引入更多的噪声;而这将不利于学习的结果;因此,在更大数据集上的分类任务或者目标检测任务中我们倾向于不适用 T T T(使 T = 1 T=1 T=1

3.3 Knowledge Distillation for Regression with Teacher Bounds

除了要进行分类任务,目标检测还需要实现边界框bounding box 的回归任务来调整输入 proposal 区域的边界框的位置和尺寸;通常学习好的回归模型是确保目标检测精度的关键。不同于知识蒸馏针对的是离散的类别,teacher 网络的回归的输出可能会很大程度上误导 student 网络,因为真值回归的输出是无限的。此外,teacher 网络可能会提供与 ground_truth 方向完全相反的回归方向。

因此,相比于直接把 teacher 网络的 regression 结果直接当做标签(target),我们尝试设定一个 upper bound(上界)来限定 student 网络;student 回归矢量应该尽可能地靠近 ground_truth 的 label;但如果 student 网络产生的结果超过了 teacher 网络一定的范围,我们不会采用额外的 loss 来惩罚 student 网络;我们把这种方式叫做 teacher bounded regression loss, L b L_b Lb 代表这个回归损失;

在这里插入图片描述

相当于在 L b L_b Lb 里面结合了 L2 损失;上面式子的意思就是,当 R s R_s Rs 和真实标签的偏差和 R t R_t Rt与真实标签的偏差在一定的范围 m m m 之外的时候,才使用损失函数对结果进行惩罚,在一定范围 m m m 之内的时候,不对这个结果进行惩罚。

回归部分的整体损失公式如下:

L r e g = L s L 1 ( R s , y r e g ) + ν L b ( R s , R t , y r e g ) L_{reg}=L_{sL1}(R_s,y_{reg})+\nu L_b(R_s,R_t,y_{reg}) Lreg=LsL1(Rs,yreg)+νLb(Rs,Rt,yreg) ·········· ( 4 ) (4) (4)

  • m m m 是一个范围值

  • y r e g y_{reg} yreg 代表回归的 ground_truth label

  • R s R_s Rs 是student 网络的回归输出结果

  • R t R_t Rt 是 teacher 网络回归的预测结果

  • ν \nu ν 是一个权重参数(实验的时候设为0.5)

  • 这里的 L s L 1 L_{sL1} LsL1 是平滑的 L 1 L1 L1 loss;
    在这里插入图片描述

    smooth L1的好处主要是能够避开 L1 和 L2 loss 的缺点,防止梯度爆炸,网上有很多文章,可以参考。

  • teacher bounded 回归损失 L b L_b Lb 只有在 student 产生的错误比 teacher 大 m m m 或者更多的时候,才会对 student 的训练结果进行惩罚

  • 注意,虽然这里我们在 L b L_b Lb 里面使用了 L 2 L2 L2 loss,但是其他任何回归损失如: L 1 l o s s L1 loss L1loss S m o o t h e d L 1 Smoothed L1 SmoothedL1 都可以结合在 L b L_b Lb

  • 我们的共同损失鼓励了 student 网络在回归方面接近或优于 teacher 网络的结果,但一旦达到了 teacher 网络的水平,就不会对 student 网络的表现产生什么影响了

3.4 Hint Learning with Feature Adaption

知识蒸馏通常完成信息和知识的传递只使用最后一层(final output);在之前的工作中,Romero 等人证明了使用 teacher 网络的中间层表示来作为 hint 可以帮助训练和提升 student 网络的表现;他们使用了特征矢量 V V V Z Z Z 之间的 L2 距离

L H i n t ( V , Z ) = ∣ ∣ V − Z ∣ ∣ 2 2 L_{Hint}(V,Z)=||V-Z||_2^2 LHint(V,Z)=VZ22 ·········· ( 5 ) (5) (5)

  • Z Z Z 代表我们选用的 teacher 网络的中间层作为 hint, V V V 代表 student 网络中被引导层的输出;我们也用 L 1 L1 L1 距离来做了评估:

L H i n t ( V , Z ) = ∣ ∣ V − Z ∣ ∣ 1 L_{Hint}(V,Z)=||V-Z||_1 LHint(V,Z)=VZ1 ·········· ( 6 ) (6) (6)

  • 我们引入了 hint learning,要求引导和被引导的两个层需要有相同数量的神经元(通道,宽度,高度);这句话的意思就是:从 teacher 网络挑选出的 hint 层要与他指导的 student 里面对应的 guided 层的结构相同;
    在这里插入图片描述
  • 为了匹配这两个层的结构,我们在中间引入一个 adaption 层;让 student 中 guided 层的输出结果进入 adaption 层进行结构调整,使得从 adaption 层出来的结果能够具有和 teacher 网络的 hint 层一样的结构。
  • adaption 层本质上是一个全连接层,当然 hint 层和 guided 层也都是全连接层。当 hint 和 guided 层都是卷积层时,我们采用 1 × 1 1×1 1×1的卷积来节省内存。
  • 有趣的是,我们发现多一个 adaption 层可以高效地把知识从一个网络传输到另外一个网络,即使 hint 层和 guided 层的 channel 数是一样的,也同样有效;在 hint 层和 guided 层结构不一样的时候,可以将他们的结构进行调整从而实现匹配和知识传输。
  • 当 hint 层或者 guided 层是卷积层并且 hint 层和 guided 层的分辨率不同时(例如,VGG16, AlexNet) 我们遵循 Compressing deep convolutional networks using vector quantization 中引入的填充技巧(padding trick)来匹配输出的数量
  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

暖仔会飞

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值