Reproducing "Distilling Object Detectors with Fine-grained Feature Imitation"

The reproduction is based on the authors' open-source code: https://github.com/twangnh/Distilling-Object-Detectors

Code issues and implementation details can be discussed on my GitHub:

https://github.com/HqWei/Distillation-of-Faster-rcnn

At its core, this paper improves feature-level distillation for object detection, so the first step is plain feature-map distillation for a detector, which is straightforward to implement:

sup_feature = output_teacher['features'][0]
stu_feature = output['features'][0]
# model_adap is a conv layer + ReLU: it adapts the student's feature map to the
# teacher's (matching channel count) so the L2 distance can be computed directly.
stu_feature_adap = model_adap(stu_feature)

start_weight = cfg_feature_distillation.get('start_weight')
end_weight = cfg_feature_distillation.get('end_weight')

# Linearly anneal the imitation weight from start_weight to end_weight over training
imitation_loss_weight = start_weight + (end_weight - start_weight) * (float(epoch) / max_epoch)
# L2 distance: sum of squared differences at corresponding feature-map positions
sup_loss = (torch.pow(sup_feature - stu_feature_adap, 2)).sum()
sup_loss = sup_loss * imitation_loss_weight
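For context, here is a minimal sketch of how this loss would typically be folded into a training step. The detection-loss keys and the optimizer name are illustrative assumptions, not names from the original repo:

# Hypothetical training-step excerpt; det_loss stands in for the detector's
# usual classification + regression losses ('rpn.loss'/'rcnn.loss' are assumed keys).
det_loss = output['rpn.loss'] + output['rcnn.loss']
total_loss = det_loss + sup_loss  # the imitation loss is simply added on top
optimizer.zero_grad()
total_loss.backward()
optimizer.step()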

Then comes the paper's core contribution: the student does not need to imitate the teacher over the entire feature map, only in the regions near the ground-truth (GT) boxes.

The main difficulty is generating the imitation mask. The paper describes it as follows:

Specifically, as shown in Fig. 2, for each ground truth box, we compute the IOU between it and all anchors, which forms a W × H × K IOU map m. Here W and H denote width and height of the feature map, and K indicates the K preset anchor boxes. Then we find the largest IOU value M = max(m), times the thresholding factor ψ to obtain a filter threshold F = ψ ∗ M. With F, we filter the IOU map to keep those larger then F locations and combine them with OR operation to get a W × H mask.

In short: first compute the IOU between each GT box and all anchors, which yields a W × H × K IOU map m. W and H are the feature map's width and height, and K is the number of anchors generated per location (e.g., with anchor ratios 0.5, 1, 2 and scales 2, 4, 8, 16, 32, K = 3 × 5 = 15). Each anchor shape thus produces a W × H IOU score map, where the value at each position is the IOU between the anchor placed there and the GT box; across the K maps, only the largest response per location matters, since a single high IOU already means that location is near the GT. The final output is a binary W × H mask of 0s and 1s. How is it binarized? With a threshold, chosen rather cleverly in this paper as half of the maximum IOU value (F = 0.5 ∗ M); the per-GT masks are then merged with an OR operation, as the toy demonstration below illustrates.
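To make the thresholding concrete, here is a self-contained toy demonstration for a single GT box (the tiny shapes and random IOUs are made up purely for illustration):

import torch

# Toy IOU map for one GT box: H=W=4 locations, K=2 anchor shapes per location
H, W, K = 4, 4, 2
torch.manual_seed(0)
iou_map = torch.rand(H, W, K)
M = iou_map.max()                # largest IOU value for this GT: M = max(m)
F = 0.5 * M                      # filter threshold F = psi * M, with psi = 0.5
# OR over the K anchor shapes: keep a location if any of its anchors passes
mask = (iou_map > F).any(dim=2)  # (H, W) boolean mask
print(mask.int())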

My reproduction:

# Generate the masks: one mask per image in the batch
mask_batch = []
if cfg.get('need_mask', None):
    for i in range(B):
        # K1: number of anchors generated at each feature-map location
        K1 = int(all_anchors.shape[0] / (height * width))
        # A: number of GT boxes for this image
        A = gt_bboxes[i].shape[0]
        gt_boxes = gt_bboxes[i]
        gt_boxes = gt_boxes.view(1, gt_boxes.shape[0], gt_boxes.shape[1])
        # IOU between every anchor and every GT box, reshaped to (H, W, K1, A)
        IOU_map = bbox_overlaps_batch(all_anchors, gt_boxes).view(height, width, K1, A)
        # Largest IOU per GT box over all locations and anchor shapes: M = max(m)
        max_iou, _ = torch.max(IOU_map.view(height * width * K1, A), dim=0)
        mask_per_im = torch.zeros([height, width], dtype=torch.int64).cuda()
        # Walk through every GT box
        for k in range(gt_boxes.shape[1]):
            # Padded (all-zero) boxes mark the end of the valid GT list
            if torch.sum(gt_boxes[0][k]) == 0.:
                break
            # Filter threshold F = psi * M, with psi = 0.5
            max_iou_per_gt = max_iou[k] * 0.5
            # OR over the K1 anchor shapes: keep a location if any anchor passes
            mask_per_gt = torch.sum(IOU_map[:, :, :, k] > max_iou_per_gt, dim=2)
            mask_per_im += mask_per_gt
        mask_batch.append(mask_per_im)
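Note that mask_per_im accumulates integer counts rather than booleans, so a location covered by several GT boxes can exceed 1; the binarization to {0, 1} happens later at the call site, via (mask > 0).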
The IOU map is computed with bbox_overlaps_batch, taken from the original repo's code:
def bbox_overlaps_batch(anchors, gt_boxes):
    """
    anchors: (N, 4) or (b, N, 4/5) tensor of float
    gt_boxes: (b, K, 5) tensor of float

    returns overlaps: (b, N, K) tensor of IOU between anchors and gt_boxes
    """
    batch_size = gt_boxes.size(0)

    if anchors.dim() == 2:

        N = anchors.size(0)
        K = gt_boxes.size(1)

        anchors = anchors.view(1, N, 4).expand(batch_size, N, 4).contiguous()
        gt_boxes = gt_boxes[:,:,:4].contiguous()


        gt_boxes_x = (gt_boxes[:,:,2] - gt_boxes[:,:,0] + 1)
        gt_boxes_y = (gt_boxes[:,:,3] - gt_boxes[:,:,1] + 1)
        gt_boxes_area = (gt_boxes_x * gt_boxes_y).view(batch_size, 1, K)

        anchors_boxes_x = (anchors[:,:,2] - anchors[:,:,0] + 1)
        anchors_boxes_y = (anchors[:,:,3] - anchors[:,:,1] + 1)
        anchors_area = (anchors_boxes_x * anchors_boxes_y).view(batch_size, N, 1)

        gt_area_zero = (gt_boxes_x == 1) & (gt_boxes_y == 1)
        anchors_area_zero = (anchors_boxes_x == 1) & (anchors_boxes_y == 1)

        boxes = anchors.view(batch_size, N, 1, 4).expand(batch_size, N, K, 4)
        query_boxes = gt_boxes.view(batch_size, 1, K, 4).expand(batch_size, N, K, 4)

        iw = (torch.min(boxes[:,:,:,2], query_boxes[:,:,:,2]) -
            torch.max(boxes[:,:,:,0], query_boxes[:,:,:,0]) + 1)
        iw[iw < 0] = 0

        ih = (torch.min(boxes[:,:,:,3], query_boxes[:,:,:,3]) -
            torch.max(boxes[:,:,:,1], query_boxes[:,:,:,1]) + 1)
        ih[ih < 0] = 0
        ua = anchors_area + gt_boxes_area - (iw * ih)
        overlaps = iw * ih / ua

        # mask the overlap here.
        overlaps.masked_fill_(gt_area_zero.view(batch_size, 1, K).expand(batch_size, N, K), 0)
        overlaps.masked_fill_(anchors_area_zero.view(batch_size, N, 1).expand(batch_size, N, K), -1)

    elif anchors.dim() == 3:
        N = anchors.size(1)
        K = gt_boxes.size(1)

        if anchors.size(2) == 4:
            anchors = anchors[:,:,:4].contiguous()
        else:
            anchors = anchors[:,:,1:5].contiguous()

        gt_boxes = gt_boxes[:,:,:4].contiguous()

        gt_boxes_x = (gt_boxes[:,:,2] - gt_boxes[:,:,0] + 1)
        gt_boxes_y = (gt_boxes[:,:,3] - gt_boxes[:,:,1] + 1)
        gt_boxes_area = (gt_boxes_x * gt_boxes_y).view(batch_size, 1, K)

        anchors_boxes_x = (anchors[:,:,2] - anchors[:,:,0] + 1)
        anchors_boxes_y = (anchors[:,:,3] - anchors[:,:,1] + 1)
        anchors_area = (anchors_boxes_x * anchors_boxes_y).view(batch_size, N, 1)

        gt_area_zero = (gt_boxes_x == 1) & (gt_boxes_y == 1)
        anchors_area_zero = (anchors_boxes_x == 1) & (anchors_boxes_y == 1)

        boxes = anchors.view(batch_size, N, 1, 4).expand(batch_size, N, K, 4)
        query_boxes = gt_boxes.view(batch_size, 1, K, 4).expand(batch_size, N, K, 4)

        iw = (torch.min(boxes[:,:,:,2], query_boxes[:,:,:,2]) -
            torch.max(boxes[:,:,:,0], query_boxes[:,:,:,0]) + 1)
        iw[iw < 0] = 0

        ih = (torch.min(boxes[:,:,:,3], query_boxes[:,:,:,3]) -
            torch.max(boxes[:,:,:,1], query_boxes[:,:,:,1]) + 1)
        ih[ih < 0] = 0
        ua = anchors_area + gt_boxes_area - (iw * ih)

        overlaps = iw * ih / ua

        # mask the overlap here.
        overlaps.masked_fill_(gt_area_zero.view(batch_size, 1, K).expand(batch_size, N, K), 0)
        overlaps.masked_fill_(anchors_area_zero.view(batch_size, N, 1).expand(batch_size, N, K), -1)
    else:
        raise ValueError('anchors input dimension is not correct.')

    return overlaps
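A quick shape check for this function with toy inputs (the box coordinates are made up for illustration):

import torch

anchors = torch.tensor([[0., 0., 15., 15.],
                        [8., 8., 23., 23.]])         # (N=2, 4)
gt_boxes = torch.tensor([[[4., 4., 19., 19., 1.]]])  # (b=1, K=1, 5); last column is the label
overlaps = bbox_overlaps_batch(anchors, gt_boxes)
print(overlaps.shape)  # torch.Size([1, 2, 1])
print(overlaps)        # IOU of each anchor with the GT box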
Call site in the main training loop:

# Feature-level distillation
if cfg_distillation.get('feature_distillation', None):
    cfg_feature_distillation = cfg_distillation.get('feature_distillation')
    sup_feature = output_teacher['features'][0]
    stu_feature = output['features'][0]
    stu_feature_adap = model_adap(stu_feature)

    start_weight = cfg_feature_distillation.get('start_weight')
    end_weight = cfg_feature_distillation.get('end_weight')
    # Linearly anneal the imitation weight over training
    imitation_loss_weight = start_weight + (end_weight - start_weight) * (float(epoch) / max_epoch)
    if cfg_feature_distillation.get('need_mask', None):
        # Fine-grained imitation: restrict the L2 loss to the masked regions
        mask_batch = output_teacher['RoINet.mask_batch']
        mask_list = []
        for mask in mask_batch:
            # Binarize the accumulated counts and add a channel dim for broadcasting
            mask = (mask > 0).float().unsqueeze(0)
            mask_list.append(mask)
        mask_batch = torch.stack(mask_list, dim=0)
        # Normalize by 2x the number of masked positions, as in the paper's imitation loss
        norms = mask_batch.sum() * 2
        sup_loss = (torch.pow(sup_feature - stu_feature_adap, 2) * mask_batch).sum() / norms
    else:
        # Fall back to plain full-feature-map imitation
        sup_loss = (torch.pow(sup_feature - stu_feature_adap, 2)).sum()

    sup_loss = sup_loss * imitation_loss_weight
    output['sup.loss'] = sup_loss
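One detail worth flagging: gradients should not flow into the teacher. A minimal sketch of how the teacher forward pass would typically be run (model_teacher is an assumed name for the frozen teacher network, not from the original repo):

# Run the teacher without building a graph, so the imitation loss
# only updates the student and the adaptation layer.
model_teacher.eval()
with torch.no_grad():
    output_teacher = model_teacher(input)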

Design of the adaptation module: the feature-map size is transformed by a single convolution layer, following OUT_W = (IN_W + 2 * PADDING - KERNEL_SIZE) / STRIDE + 1.

import torch.nn as nn


class Stu_Feature_Adap(nn.Module):
    """Adaptation layer: a single conv + ReLU that maps the student's feature
    map to the teacher's channel count and spatial size before the L2 loss."""

    def __init__(self, input_channel=256, output_channel=1024, kernel_size=2, padding=0):
        super(Stu_Feature_Adap, self).__init__()
        self.conv1 = nn.Conv2d(input_channel, output_channel, kernel_size=kernel_size, padding=padding)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        return x
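A quick sanity check of the output-size formula above (the 38 × 38 input size is a made-up example, not a value from the repo):

import torch

adap = Stu_Feature_Adap(input_channel=256, output_channel=1024, kernel_size=2, padding=0)
x = torch.randn(1, 256, 38, 38)
y = adap(x)
print(y.shape)  # torch.Size([1, 1024, 37, 37]); (38 + 2*0 - 2)/1 + 1 = 37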

 
