Knowledge Distillation for Object Detection

0. Preface

Common model-optimization methods fall into two camps: structural design and non-structural methods. The structural-design family includes MobileNet, GhostNet, PeleeNet, EfficientNet, ShuffleNet, VoVNet and so on, all of which are designed around compute efficiency and resource budgets and target edge and mobile devices. Beyond architecture design, the non-structural optimization methods are knowledge distillation, model quantization and model pruning. Among these, knowledge distillation is the simplest and most effective and has been widely applied to classification, but distillation methods for detection are still rare. This post introduces a feature-based knowledge-distillation method for detection.

1. Paper

Paper title:
"IMPROVE OBJECT DETECTION WITH FEATURE-BASED KNOWLEDGE DISTILLATION: TOWARDS ACCURATE AND EFFICIENT DETECTORS"
Download link:
paper page
Code links:
code released by the paper's authors
my modified version

2. Core Idea

Why distillation is harder for detection than for classification: (1) Detection depends on local features and information rather than the global information a classifier uses, so the student has to learn from every pixel of the feature map. This immediately raises a foreground/background imbalance: foreground objects cover relatively few pixels while the background covers most of the map. (2) Pixels are not independent of one another, so the distillation should also transfer the relations between pixels.
The paper addresses both difficulties: (1) an attention mechanism makes the distillation focus as much as possible on foreground pixels; (2) a NonLocalModule captures the relations between pixels.
Other contributions: the paper is essentially an improvement on the 2019 paper "Distilling Object Detectors with Fine-grained Feature Imitation". It drops the prior knowledge coming from anchors and ground-truth boxes, which makes it more flexible. Since foreground pixels also differ in importance, a continuous attention value is more appropriate than a simple binary mask. The paper also observes that in classification a stronger teacher does not necessarily teach a stronger student, whereas in detection the two are positively correlated: a stronger teacher yields a stronger student.
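
Putting these pieces together, the overall training objective is roughly the following (my notation; the weights correspond to the constants that appear in the code in section 3):

$$\mathcal{L} = \mathcal{L}_{det} + \alpha\,\mathcal{L}_{feat} + \beta\,(\mathcal{L}_{channel} + \mathcal{L}_{spatial}) + \gamma\,\mathcal{L}_{nonlocal}$$

where $\mathcal{L}_{det}$ is the student's ordinary detection loss, $\mathcal{L}_{feat}$ is the attention-masked feature-imitation loss, $\mathcal{L}_{channel}$ and $\mathcal{L}_{spatial}$ make the student's pooled channel and spatial responses mimic the teacher's, and $\mathcal{L}_{nonlocal}$ aligns the relation features produced by the NonLocalModule.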
(Figure: overall idea of the paper)
(Figure: an example)

3. Code Walkthrough

The distillation is applied to the feature maps coming out of the backbone + FPN.
The NonLocalModule that captures the relations between pixels is shown in the code below. The attention masks are built from the teacher feature $A^{T}$ and the student feature $A^{S}$ as

$$M^{s} = HW \cdot \mathrm{softmax}\!\left(\frac{\mathcal{G}^{s}(A^{S}) + \mathcal{G}^{s}(A^{T})}{T}\right),\qquad
M^{c} = C \cdot \mathrm{softmax}\!\left(\frac{\mathcal{G}^{c}(A^{S}) + \mathcal{G}^{c}(A^{T})}{T}\right),$$

where $\mathcal{G}^{s}$ averages the absolute feature values over the channel dimension, $\mathcal{G}^{c}$ averages them over the spatial dimensions, and $T$ is a temperature. $M^{s}$ is the spatial attention map and $M^{c}$ is the channel attention map; both are used to weight the feature-imitation loss.
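
Written out, the attention-guided feature loss that the dist2 function below evaluates is (my notation; $\phi$ denotes the adaptation layer applied to the student feature):

$$\mathcal{L}_{feat} = \sqrt{\sum_{k=1}^{C}\sum_{i=1}^{H}\sum_{j=1}^{W} M^{s}_{ij}\, M^{c}_{k}\,\big(A^{T}_{k,i,j} - \phi(A^{S})_{k,i,j}\big)^{2}}$$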

import torch
import torch.nn as nn

def dist2(tensor_a, tensor_b, attention_mask=None, channel_attention_mask=None):
    # element-wise squared difference between the teacher feature and the (adapted) student feature
    diff = (tensor_a - tensor_b) ** 2
    #   diff:                   batchsize x C x H x W
    #   attention_mask:         batchsize x 1 x H x W  (spatial mask)
    #   channel_attention_mask: batchsize x C x 1 x 1  (channel mask)
    if attention_mask is not None:
        diff = diff * attention_mask
    if channel_attention_mask is not None:
        diff = diff * channel_attention_mask
    # weighted L2 distance: square root of the masked sum of squared errors
    diff = torch.sum(diff) ** 0.5
    return diff


class NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=True, bn_layer=True, downsample_stride=2):
        super(NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(downsample_stride, downsample_stride))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: input feature map, (b, c, h, w) when dimension == 2
                  (or (b, c, t, h, w) when dimension == 3)
        :return: relation feature map of the same shape as x
        '''

        batch_size = x.size(0)  #   2 , 256 , 300 , 300

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)   #   2 , 128 , 150 x 150
        g_x = g_x.permute(0, 2, 1)                                  #   2 , 150 x 150, 128

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)   #   2 , 128 , 300 x 300
        theta_x = theta_x.permute(0, 2, 1)                                  #   2 , 300 x 300 , 128
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)       #   2 , 128 , 150 x 150
        f = torch.matmul(theta_x, phi_x)    #   2 , 300x300 , 150x150
        N = f.size(-1)  #   150 x 150
        f_div_C = f / N #   2 , 300x300, 150x150

        y = torch.matmul(f_div_C, g_x)  #   2, 300x300, 128
        y = y.permute(0, 2, 1).contiguous() #   2, 128, 300x300
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z
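
A minimal usage sketch of the block (the per-FPN-level ModuleList names follow the loss code further down; the 256-channel, 5-level FPN setup is my assumption):

import torch
import torch.nn as nn

# one non-local block per FPN level for both student and teacher (assumed setup)
num_fpn_levels = 5
kd_student_non_local = nn.ModuleList(
    [NonLocalBlockND(in_channels=256) for _ in range(num_fpn_levels)])
kd_teacher_non_local = nn.ModuleList(
    [NonLocalBlockND(in_channels=256) for _ in range(num_fpn_levels)])

feat = torch.randn(2, 256, 76, 76)            # one student FPN level (batch of 2)
relation = kd_student_non_local[0](feat)      # relation feature, same shape as the input
print(relation.shape)                         # torch.Size([2, 256, 76, 76])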

The code that builds the attention masks and computes the KD losses is as follows:

        # student mimic teacher loss
        t = 0.1
        s_ratio = 1.0
        kd_feat_loss = 0
        kd_channel_loss = 0
        kd_spatial_loss = 0

        #   for channel attention
        c_t = 0.1
        c_s_ratio = 1.0

        if teacher_info is not None:
            t_feats = teacher_info['teacher_feat']
            for _i in range(len(t_feats)):
                t_attention_mask = torch.mean(torch.abs(t_feats[_i]), [1], keepdim=True)
                size = t_attention_mask.size()
                t_attention_mask = t_attention_mask.view(fpn_in_features[0].size(0), -1)
                t_attention_mask = torch.softmax(t_attention_mask / t, dim=1) * size[-1] * size[-2]
                t_attention_mask = t_attention_mask.view(size)

                s_attention_mask = torch.mean(torch.abs(fpn_in_features[_i]), [1], keepdim=True)
                size = s_attention_mask.size()
                s_attention_mask = s_attention_mask.view(fpn_in_features[0].size(0), -1)
                s_attention_mask = torch.softmax(s_attention_mask / t, dim=1) * size[-1] * size[-2]
                s_attention_mask = s_attention_mask.view(size)

                c_t_attention_mask = torch.mean(torch.abs(t_feats[_i]), [2, 3], keepdim=True)  # 2 x 256 x 1 x1
                c_size = c_t_attention_mask.size()
                c_t_attention_mask = c_t_attention_mask.view(fpn_in_features[0].size(0), -1)  # 2 x 256
                c_t_attention_mask = torch.softmax(c_t_attention_mask / c_t, dim=1) * 256  # 256 = C, the FPN channel count
                c_t_attention_mask = c_t_attention_mask.view(c_size)  # 2 x 256 -> 2 x 256 x 1 x 1

                c_s_attention_mask = torch.mean(torch.abs(fpn_in_features[_i]), [2, 3], keepdim=True)  # 2 x 256 x 1 x1
                c_size = c_s_attention_mask.size()
                c_s_attention_mask = c_s_attention_mask.view(fpn_in_features[0].size(0), -1)  # 2 x 256
                c_s_attention_mask = torch.softmax(c_s_attention_mask / c_t, dim=1) * 256  # 256 = C, the FPN channel count
                c_s_attention_mask = c_s_attention_mask.view(c_size)  # 2 x 256 -> 2 x 256 x 1 x 1

                sum_attention_mask = (t_attention_mask + s_attention_mask * s_ratio) / (1 + s_ratio)
                sum_attention_mask = sum_attention_mask.detach()

                c_sum_attention_mask = (c_t_attention_mask + c_s_attention_mask * c_s_ratio) / (1 + c_s_ratio)
                c_sum_attention_mask = c_sum_attention_mask.detach()

                kd_feat_loss += dist2(t_feats[_i], self.kd_adaptation_layers[_i](fpn_in_features[_i]), attention_mask=sum_attention_mask,
                                      channel_attention_mask=c_sum_attention_mask) * 7e-5 * 6
                kd_channel_loss += torch.dist(torch.mean(t_feats[_i], [2, 3]),
                                              self.kd_channel_wise_adaptation[_i](torch.mean(fpn_in_features[_i], [2, 3]))) * 4e-3 * 6
                t_spatial_pool = torch.mean(t_feats[_i], [1]).view(t_feats[_i].size(0), 1, t_feats[_i].size(2),
                                                                   t_feats[_i].size(3))
                s_spatial_pool = torch.mean(fpn_in_features[_i], [1]).view(fpn_in_features[_i].size(0), 1, fpn_in_features[_i].size(2),
                                                             fpn_in_features[_i].size(3))
                kd_spatial_loss += torch.dist(t_spatial_pool, self.kd_spatial_wise_adaptation[_i](s_spatial_pool)) * 4e-3 * 6

        losses.update({'kd_feat_loss': kd_feat_loss})
        losses.update({'kd_channel_loss': kd_channel_loss})
        losses.update({'kd_spatial_loss': kd_spatial_loss})

        kd_nonlocal_loss = 0
        if teacher_info is not None:
            t_feats = teacher_info['teacher_feat']
            for _i in range(len(t_feats)):
                s_relation = self.kd_student_non_local[_i](fpn_in_features[_i])
                t_relation = self.kd_teacher_non_local[_i](t_feats[_i])
                #   print(s_relation.size())
                kd_nonlocal_loss += torch.dist(self.kd_non_local_adaptation[_i](s_relation), t_relation, p=2)
        #losses.update(kd_nonlocal_loss=kd_nonlocal_loss * 7e-5 * 6)
        losses.update({'kd_nonlocal_loss': kd_nonlocal_loss * 7e-5 * 6})
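
The loss code above references several adaptation layers (kd_adaptation_layers, kd_channel_wise_adaptation, kd_spatial_wise_adaptation, kd_non_local_adaptation) that are not shown in this post. A plausible construction, assuming 256-channel FPN features for both teacher and student and 5 FPN levels (the official repo may define them slightly differently):

import torch.nn as nn

num_fpn_levels = 5
s_ch, t_ch = 256, 256  # assumed student / teacher FPN channel counts

# 1x1 convs that project student features into the teacher's feature space (used in kd_feat_loss)
kd_adaptation_layers = nn.ModuleList(
    [nn.Conv2d(s_ch, t_ch, kernel_size=1) for _ in range(num_fpn_levels)])

# per-level linear layers for the spatially pooled B x C vectors (used in kd_channel_loss)
kd_channel_wise_adaptation = nn.ModuleList(
    [nn.Linear(s_ch, t_ch) for _ in range(num_fpn_levels)])

# small convs on the channel-pooled B x 1 x H x W maps (used in kd_spatial_loss)
kd_spatial_wise_adaptation = nn.ModuleList(
    [nn.Conv2d(1, 1, kernel_size=3, padding=1) for _ in range(num_fpn_levels)])

# 1x1 convs applied to the student's non-local relation features (used in kd_nonlocal_loss)
kd_non_local_adaptation = nn.ModuleList(
    [nn.Conv2d(s_ch, t_ch, kernel_size=1) for _ in range(num_fpn_levels)])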

4. Summary

That's the knowledge-distillation approach to model optimization, done! To be honest, the author's code is not exactly elegant (he is a Tsinghua expert, and in the end the idea is what matters most; the paper has only just come out, so there was probably no time to tidy the code, and releasing code at all is already commendable, unlike certain companies that never release theirs). Thanks to the author.
