[Knowledge Distillation] Feature-Based Knowledge Distillation for Multi-Task Models in Practice

1. Implementation Workflow

(1) Define the student and teacher models
(2) Define the feature distillation losses

  • Mimic Loss
  • CWD Loss
  • MGD Loss
  • Feature Loss

(3) Use hooks to capture the feature layers to distill

  • Define the hook callback
  • Register the hooks
  • Collect the features to distill

(4) Compute the feature distillation loss
(5) Compute the total loss and backpropagate

  • Compute the total loss
  • Backpropagate

(6) Save the distilled model

  • Remove the hooks
  • Save the distilled model

2. Code Implementation

(1) Define the student and teacher models

# Student model (a fully serialized module)
model = torch.load(args.student_model, map_location=device)
# Teacher model: rebuild the architecture, then load the trained weights
teacher_model = YoloBody(num_det=config.DET_NUM_CLASSES, num_seg=config.SEG_NUM_CLASSES, phi=args.phi, task="multi", use_aspp=False)
teacher_model.load_state_dict(torch.load(args.teacher_model, map_location=device)['model'])
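
Before the distillation loop starts, the teacher should be frozen and switched to eval mode so its weights and BatchNorm statistics stay fixed; a minimal sketch using the variables above:

# Teacher is used for inference only
teacher_model = teacher_model.to(device).eval()
for p in teacher_model.parameters():
    p.requires_grad_(False)

# Student stays trainable
model = model.to(device).train()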

(2) Define the feature distillation losses

  • Mimic Loss (a stage-wise MSE between student and teacher feature maps)
import torch
import torch.nn as nn
import torch.nn.functional as F

class MimicLoss(nn.Module):
    def __init__(self, channels_s, channels_t):
        super(MimicLoss, self).__init__()
        self.mse = nn.MSELoss()

    def forward(self, y_s, y_t):
        """Forward computation.
        Args:
            y_s (list): The student model predictions, each with
                shape (N, C, H, W).
            y_t (list): The teacher model predictions, each with
                shape (N, C, H, W).
        Return:
            torch.Tensor: The summed loss over all stages.
        """
        assert len(y_s) == len(y_t)
        losses = []
        for s, t in zip(y_s, y_t):
            assert s.shape == t.shape
            losses.append(self.mse(s, t))
        return sum(losses)
  • CWD Loss
class CWDLoss(nn.Module):
    """PyTorch version of `Channel-wise Distillation for Semantic Segmentation
    <https://arxiv.org/abs/2011.13256>`_.
    """

    def __init__(self, channels_s, channels_t, tau=1.0):
        super(CWDLoss, self).__init__()
        self.tau = tau

    def forward(self, y_s, y_t):
        """Forward computation.
        Args:
            y_s (list): The student model predictions, each with
                shape (N, C, H, W).
            y_t (list): The teacher model predictions, each with
                shape (N, C, H, W).
        Return:
            torch.Tensor: The summed loss over all stages.
        """
        assert len(y_s) == len(y_t)
        losses = []
        for s, t in zip(y_s, y_t):
            assert s.shape == t.shape
            N, C, H, W = s.shape

            # Normalize each channel's spatial map into a distribution
            # (softmax over the H*W positions, with temperature tau).
            softmax_pred_T = F.softmax(t.view(-1, W * H) / self.tau, dim=1)  # [N*C, H*W]

            # KL divergence between teacher and student channel distributions.
            cost = torch.sum(
                softmax_pred_T * F.log_softmax(t.view(-1, W * H) / self.tau, dim=1) -
                softmax_pred_T * F.log_softmax(s.view(-1, W * H) / self.tau, dim=1)) * (self.tau ** 2)

            losses.append(cost / (C * N))
        return sum(losses)
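
Written out (matching the code above, which treats each of the N·C channel maps as a distribution over its H·W positions):

\phi(y_{c,i}) = \frac{\exp(y_{c,i}/\tau)}{\sum_{j=1}^{HW}\exp(y_{c,j}/\tau)}, \qquad
\mathcal{L}_{\mathrm{CWD}} = \frac{\tau^{2}}{N C}\sum_{c=1}^{N C}\sum_{i=1}^{HW} \phi\left(y^{T}_{c,i}\right) \log\frac{\phi\left(y^{T}_{c,i}\right)}{\phi\left(y^{S}_{c,i}\right)}

i.e. a temperature-scaled KL divergence between the teacher's and the student's per-channel spatial distributions.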
  • MGD Loss
class MGDLoss(nn.Module):
    def __init__(self, channels_s, channels_t, alpha_mgd=0.00002, lambda_mgd=0.65):
        super(MGDLoss, self).__init__()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.alpha_mgd = alpha_mgd
        self.lambda_mgd = lambda_mgd

        # nn.ModuleList (not a plain Python list) so the generator parameters
        # are registered with the module, trained, and saved.
        self.generation = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(channel, channel, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel, channel, kernel_size=3, padding=1)).to(device)
            for channel in channels_t
        ])

    def forward(self, y_s, y_t):
        """Forward computation.
        Args:
            y_s (list): The student model predictions, each with
                shape (N, C, H, W).
            y_t (list): The teacher model predictions, each with
                shape (N, C, H, W).
        Return:
            torch.Tensor: The summed loss over all stages.
        """
        assert len(y_s) == len(y_t)
        losses = []
        for idx, (s, t) in enumerate(zip(y_s, y_t)):
            assert s.shape == t.shape
            losses.append(self.get_dis_loss(s, t, idx) * self.alpha_mgd)
        return sum(losses)

    def get_dis_loss(self, preds_S, preds_T, idx):
        loss_mse = nn.MSELoss(reduction='sum')
        N, C, H, W = preds_T.shape

        device = preds_S.device
        # Binary mask that zeroes each spatial position with probability lambda_mgd.
        mat = torch.rand((N, 1, H, W), device=device)
        mat = torch.where(mat > 1 - self.lambda_mgd, 0, 1)

        # The generator must reconstruct the full teacher feature
        # from the masked student feature.
        masked_fea = torch.mul(preds_S, mat)
        new_fea = self.generation[idx](masked_fea)

        dis_loss = loss_mse(new_fea, preds_T) / N
        return dis_loss
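
For context: MGD (Masked Generative Distillation, arXiv:2205.01529) works differently from the pointwise losses above. The random mask zeroes roughly a lambda_mgd fraction of the student feature's spatial positions, the small two-conv generator has to reconstruct the complete teacher feature from what remains, and the reconstruction MSE (scaled by alpha_mgd) is the distillation signal, which pushes the student features to carry enough information to predict the teacher's.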
  • Feature Loss (wrapper that aligns channels, normalizes, then applies the chosen distiller)
class FeatureLoss(nn.Module):
    def __init__(self, channels_s, channels_t, distiller='cwd', loss_weight=1.0):
        super(FeatureLoss, self).__init__()
        self.loss_weight = loss_weight

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # 1x1 convs that project student channels to the teacher's widths.
        self.align_module = nn.ModuleList([
            nn.Conv2d(channel, tea_channel, kernel_size=1, stride=1, padding=0).to(device)
            for channel, tea_channel in zip(channels_s, channels_t)
        ])
        # nn.ModuleList so the BatchNorm running statistics are registered and
        # move/save with the module (affine=False: no learned scale or shift).
        self.norm = nn.ModuleList([
            nn.BatchNorm2d(tea_channel, affine=False).to(device)
            for tea_channel in channels_t
        ])

        if distiller == 'mimic':
            self.feature_loss = MimicLoss(channels_s, channels_t)
        elif distiller == 'mgd':
            self.feature_loss = MGDLoss(channels_s, channels_t)
        elif distiller == 'cwd':
            self.feature_loss = CWDLoss(channels_s, channels_t)
        else:
            raise NotImplementedError(distiller)

    def forward(self, y_s, y_t):
        assert len(y_s) == len(y_t)
        tea_feats = []
        stu_feats = []

        for idx, (s, t) in enumerate(zip(y_s, y_t)):
            s = self.align_module[idx](s)  # match the teacher channel count
            s = self.norm[idx](s)
            t = self.norm[idx](t)
            tea_feats.append(t)
            stu_feats.append(s)

        loss = self.feature_loss(stu_feats, tea_feats)
        return self.loss_weight * loss
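
A quick self-contained smoke test (toy shapes and names, not the real model) to confirm each distiller runs and returns a scalar:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
chs_s, chs_t = [16, 32], [32, 64]          # hypothetical stage widths
feats_s = [torch.randn(2, c, 20, 20, device=device) for c in chs_s]
feats_t = [torch.randn(2, c, 20, 20, device=device) for c in chs_t]

for name in ('mimic', 'cwd', 'mgd'):
    crit = FeatureLoss(chs_s, chs_t, distiller=name)
    print(name, crit(feats_s, feats_t).item())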

(3) Use hooks to capture the feature layers to distill

  • Define the hook callback
# Captured features, keyed by the name passed to get_activation.
activation = {}

def get_activation(name):
    def hook(model, inputs, outputs):
        activation[name] = outputs
    return hook
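
The closure pattern is what makes this work: get_activation(name) captures name, so every registered hook writes its module's output under a different key of the shared activation dict. A tiny standalone demo (toy Conv2d, not the YOLO backbone):

import torch
import torch.nn as nn

toy = nn.Conv2d(3, 8, kernel_size=3, padding=1)
handle = toy.register_forward_hook(get_activation("toy_conv"))
_ = toy(torch.randn(1, 3, 16, 16))
print(activation["toy_conv"].shape)  # torch.Size([1, 8, 16, 16])
handle.remove()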
  • Register the hooks
def get_hooks():
    hooks = []
    # Uncomment to inspect the module names available for hooking:
    # for k, v in teacher_model._modules.items():
    #     print(f"teacher_model._modules k: {k}; v: {v}")

    # Student backbone stages
    hooks.append(model._modules['backbone'].stem.register_forward_hook(get_activation("s_stem")))
    hooks.append(model._modules['backbone'].dark2.register_forward_hook(get_activation("s_dark2")))
    hooks.append(model._modules['backbone'].dark3.register_forward_hook(get_activation("s_dark3")))
    hooks.append(model._modules['backbone'].dark4.register_forward_hook(get_activation("s_dark4")))
    hooks.append(model._modules['backbone'].dark5.register_forward_hook(get_activation("s_dark5")))
    # Teacher backbone stages (note the extra 'module' level in the teacher's tree)
    hooks.append(teacher_model._modules['module'].backbone.stem.register_forward_hook(get_activation("t_stem")))
    hooks.append(teacher_model._modules['module'].backbone.dark2.register_forward_hook(get_activation("t_dark2")))
    hooks.append(teacher_model._modules['module'].backbone.dark3.register_forward_hook(get_activation("t_dark3")))
    hooks.append(teacher_model._modules['module'].backbone.dark4.register_forward_hook(get_activation("t_dark4")))
    hooks.append(teacher_model._modules['module'].backbone.dark5.register_forward_hook(get_activation("t_dark5")))
    return hooks
  • Collect the features to distill
stu_features = [activation["s_stem"], activation["s_dark2"], activation["s_dark3"], activation["s_dark4"], activation["s_dark5"]]
tea_features = [activation["t_stem"], activation["t_dark2"], activation["t_dark3"], activation["t_dark4"], activation["t_dark5"]]
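
These lookups only succeed after a forward pass has run, so inside the training loop both networks must process the current batch before the two lists above are collected; a minimal sketch (images stands for the current batch, a hypothetical name here):

hooks = get_hooks()
with torch.no_grad():
    _ = teacher_model(images)  # fills the t_* entries, no teacher gradients
_ = model(images)              # fills the s_* entries, keeps the student graph

Running the teacher under torch.no_grad() also means the captured teacher features carry no autograd graph, so gradients flow only through the student side of the distillation loss.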

(4) Compute the feature distillation loss

# Instantiate the feature-distillation loss
channels_s = [16, 32, 64, 128, 256]   # student backbone stage widths
channels_t = [32, 64, 128, 256, 512]  # teacher backbone stage widths
distill_feat_type = 'mimic'
distill_loss = FeatureLoss(channels_s=channels_s, channels_t=channels_t, distiller=distill_feat_type)

# Compute the distillation loss
distill_weight = 1
dfea_loss = distill_loss(stu_features, tea_features) * distill_weight
print('---------dfea_loss---------- :', dfea_loss)

(5) Compute the total loss and backpropagate

  • Compute the total loss
Bev_loss = (Bev_det_distill_loss * config.Ratio_det + Bev_seg_distill_loss * config.Ratio_seg) * (1 - distill_.mtl_feature_alpha) + dfea_loss * distill_.mtl_feature_alpha
  • Backpropagate
Bev_loss.backward()
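
backward() only accumulates gradients; the usual optimizer bookkeeping around it still applies. Note that FeatureLoss itself holds trainable parameters (the 1x1 align convs and, for MGD, the generator), so they normally need to be optimized alongside the student. A sketch with a hypothetical SGD optimizer:

optimizer = torch.optim.SGD(
    list(model.parameters()) + list(distill_loss.parameters()),
    lr=0.01, momentum=0.9)

optimizer.zero_grad()
Bev_loss.backward()
optimizer.step()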

(6) Save the distilled model

  • Remove the hooks
# Remove the hooks first; otherwise torch.save fails because the
# hook closures attached to the modules cannot be pickled.
for hook in hooks:
    hook.remove()
  • Save the distilled model
torch.save(model, distill_path)
print('save distill model')
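
One caveat: torch.save(model, distill_path) pickles the entire module object, which ties the checkpoint to the exact class definitions on the import path; saving model.state_dict() instead is the more portable convention when only the weights are needed.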