Focal and Global Knowledge Distillation for Detectors--FGD论文解读

论文:Focal and Global Knowledge Distillation for Detectors

论文:https://arxiv.org/abs/2111.11837

代码:https://github.com/yzd-v/FGD

一,针对问题

1. 目标检测中前背景不平衡问题

        知识蒸馏旨在使学生学习教师的知识,以获得相似的输出从而提升性能。为了探索学生与教师在特征层面的差异,作者首先对二者的特征图进行了可视化。可以看到在空间与通道注意力上,教师与学生均存在较大的差异。其中在空间注意力上,二者在前景中的差异较大,在背景中的差异较小,这会给蒸馏中的学生带来不同的学习难度。

         为了进一步探索前背景对于知识蒸馏的影响,作者分离出前背景进行了蒸馏实验,全图一起蒸馏会导致蒸馏性能的下降,将前景与背景分开学生能够获得更好的表现。

         针对学生与教师注意力的差异,前景与背景的差异,作者提出了重点蒸馏Focal Distillation:分离前背景,并利用教师的空间与通道注意力作为权重,指导学生进行知识蒸馏,计算重点蒸馏损失。

二,方法

整体蒸馏损失计算方式:

C,H,W:feature map的通道时和高宽。

 F^TF^{S}为教师和学生模型的输出。

 2.1 分离前背景

前、背景Mask

设置一个二值MASK:

r代表GT bbox,如果feature map的点落在bbox内则该点为1,否则为0.

 2.2 尺度

尺度Mask

 大小目标focal,前、背景

Hr和Wr为bbox的高和宽,如果一个同时属于多个目标(遮挡场景)选取bbox最小的目标去计算S

 2.2 空间与通道注意力

 空间与通道注意力

 C,H,W:feature map的通道时和高宽。

GG^S ,G^{C} 代表空间注意立和通道注意力机制,

Attention MASK:

T为蒸馏温度 ,论文设置为0.5

2.3 全局蒸馏

全局信息的丢失

        Focal Distillation将前景背景分开进行蒸馏,割断了前背景的联系,缺乏了特征的全局信息的蒸馏。为此提出了全局蒸馏Global Distillation:利用GcBlock分别提取学生与教师的全局信息,并进行全局蒸馏损失的计算。

 使用GCBlock去获取全局信息,使得学生模型从教室模型中学习前背景的联系。

损失计算如下:

    
        self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)
        self.channel_add_conv_s = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
            nn.LayerNorm([teacher_channels//2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
        self.channel_add_conv_t = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
            nn.LayerNorm([teacher_channels//2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))

    def spatial_pool(self, x, in_type):
        batch, channel, width, height = x.size()
        input_x = x
        # [N, C, H * W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]
        if in_type == 0:
            context_mask = self.conv_mask_s(x)
        else:
            context_mask = self.conv_mask_t(x)
        # [N, 1, H * W]
        context_mask = context_mask.view(batch, 1, height * width)
        # [N, 1, H * W]
        context_mask = F.softmax(context_mask, dim=2)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(-1)
        # [N, 1, C, 1]
        context = torch.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = context.view(batch, channel, 1, 1)

        return context

   
    def get_rela_loss(self, preds_S, preds_T):
        loss_mse = nn.MSELoss(reduction='sum')

        context_s = self.spatial_pool(preds_S, 0)
        context_t = self.spatial_pool(preds_T, 1)

        out_s = preds_S
        out_t = preds_T

        channel_add_s = self.channel_add_conv_s(context_s)
        out_s = out_s + channel_add_s

        channel_add_t = self.channel_add_conv_t(context_t)
        out_t = out_t + channel_add_t

        rela_loss = loss_mse(out_s, out_t)/len(out_s)
        
        return rela_loss

 2.4 最终Loss


alpha=0.001,beta=0.0005

除此之外,利用L_{at}注意力损失来强迫学生模型去逼近教师模型的空间和通道注意力Mask
公式如下:


gamma=0.0005.

最终loss

lambda=0.000005

关于超参

最终效果:

 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值