CVPR 2022 | Focal and Global Knowledge Distillation for Detectors
Paper: arxiv.org/abs/2111.11837
Code: github.com/yzd-v/FGD
Explanation (Zhihu, in Chinese): https://zhuanlan.zhihu.com/p/477707304
Code walkthrough
In the part of the network that computes the loss:
#----------------------#
#   forward pass
#----------------------#
outputs = model_train(images)    # student network predictions
with torch.no_grad():
    T_pred = teacher(images)     # teacher network predictions (the teacher is frozen)
loss_value_all = 0
#----------------------#
#   compute the loss (YOLOv5 as the example)
#----------------------#
# outputs contains 3 feature levels: [20,20], [40,40] and [80,80]
# teacher_loss is the FGD loss
# The total loss is the student's original loss plus the distillation loss.
# Their magnitudes differ a lot, so in my own runs I additionally scaled the
# distillation loss by 0.1; that scaling is not applied here.
for l in range(len(outputs)):
    loss_item1 = yolo_loss(l, outputs[l], targets, y_trues[l])
    loss_item2 = teacher_loss(outputs[l], T_pred[l], targets)
    loss_value_all += loss_item1 + loss_item2
loss_value = loss_value_all
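For context, the teacher has to be loaded and frozen before this loop runs. A minimal sketch, assuming a hypothetical YoloBody constructor, checkpoint path, and device variable (none of these appear in the original post):

# hypothetical teacher setup; constructor, checkpoint path and device are placeholders
teacher = YoloBody(anchors_mask, num_classes, phi='l')   # a larger YOLOv5 variant as the teacher
teacher.load_state_dict(torch.load('teacher_weights.pth', map_location=device))
teacher = teacher.to(device).eval()                      # eval mode freezes BN statistics
for param in teacher.parameters():
    param.requires_grad = False                          # no gradient updates for the teacher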
So what is teacher_loss?
Take the 20x20 feature level as an example: its shape is [N, 3 * (5 + num_classes), 20, 20].
My dataset has 2 classes, so the channel count below is 3 * (5 + 2) = 21.
teacher_loss = FeatureLoss(21, 21)
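Note that FeatureLoss itself contains trainable 1x1 convolutions, so its parameters must be registered with the optimizer alongside the student's. A sketch of one possible setup (the ModuleList and optimizer wiring are illustrative, not from the post), using one FeatureLoss per feature level:

# one FeatureLoss per output level; all three YOLOv5 levels here have 3*(5+2)=21 channels
teacher_losses = nn.ModuleList([FeatureLoss(21, 21) for _ in range(3)]).to(device)
optimizer = torch.optim.SGD(
    list(model_train.parameters()) + list(teacher_losses.parameters()),
    lr=1e-2, momentum=0.937
)
# inside the training loop: loss_item2 = teacher_losses[l](outputs[l], T_pred[l], targets)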
import torch
import torch.nn as nn
import torch.nn.functional as F
class FeatureLoss(nn.Module):
"""PyTorch version of `Focal and Global Knowledge Distillation for Detectors`
Args:
student_channels(int): Number of channels in the student's feature map.
teacher_channels(int): Number of channels in the teacher's feature map.
temp (float, optional): Temperature coefficient. Defaults to 0.5.
name (str): the loss name of the layer
alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001
beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005
gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001
lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005
"""
    # student_channels / teacher_channels correspond to C in the feature map's (N, C, H, W)
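    # The total FGD loss is the weighted sum of the four terms named above:
    #   L = alpha_fgd * L_fg + beta_fgd * L_bg + gamma_fgd * L_mask + lambda_fgd * L_relation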
    def __init__(self,
                 student_channels,
                 teacher_channels,
                 name='Distill',
                 temp=0.5,
                 alpha_fgd=0.001,
                 beta_fgd=0.0005,
                 gamma_fgd=0.001,
                 lambda_fgd=0.000005,
                 ):
        super(FeatureLoss, self).__init__()
        self.temp = temp
        self.alpha_fgd = alpha_fgd
        self.beta_fgd = beta_fgd
        self.gamma_fgd = gamma_fgd
        self.lambda_fgd = lambda_fgd
        # if the channel counts differ, a 1x1 conv aligns the student's channels to the teacher's
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.align = None
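        # 1x1 convs producing the spatial attention maps for the GcBlock-style
        # context modeling used by the global (relation) distillation branch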
        self.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1)
        self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)
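        # GCNet-style channel-add bottleneck: re-projects the pooled global
        # context back onto the feature channels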
        self.channel_add_conv_s = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels // 2, kernel_size=1),
            nn.LayerNorm([teacher_channels // 2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels //