蒸馏论文五(Neuron Selectivity Transfer)

本文介绍了一种知识蒸馏的方法(Neuron Selectivity Transfer)

1. 主要思想

在这里插入图片描述
如图所示为神经元选择性迁移的架构。学生网络不仅利用真正的标签训练,而且还模仿了教师网络中间层的激活分布。图中的每个点或三角形表示其对应的滤波器的激活图。

2. 损失函数

最大平均差异(MMD)用作损失函数来衡量教师和学生特征之间的差异。MMD的想法就是求两个随机变量在高维空间中均值的距离
在这里插入图片描述
可以应用内核技巧展开,将上式子转化成如下形式:
在这里插入图片描述
在应用中,上述xy为学生网络和教师网络归一化后的特征图,因此:
在这里插入图片描述
这里的k为核函数,有以下几种可采用的形式:
在这里插入图片描述
在应用中,作者采用了d=2, c=0的多项式核函数。

class NSTLoss(nn.Module):
    """like what you like: knowledge distill via neuron selectivity transfer"""
    def __init__(self):
        super(NSTLoss, self).__init__()
        pass

    def forward(self, g_s, g_t):
        return [self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]

    def nst_loss(self, f_s, f_t):
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        
        # 通过pooling将teacher features和student features调整为统一大小
        # pooling成两者中较小的size
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
        else:
            pass

        f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)
        f_s = F.normalize(f_s, dim=2)
        f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
        f_t = F.normalize(f_t, dim=2)

        # set full_loss as False to avoid unnecessary computation
        full_loss = True
        if full_loss:
            return (self.poly_kernel(f_t, f_t).mean().detach() + self.poly_kernel(f_s, f_s).mean()
                    - 2 * self.poly_kernel(f_s, f_t).mean())
        else:
            return self.poly_kernel(f_s, f_s).mean() - 2 * self.poly_kernel(f_s, f_t).mean()

    def poly_kernel(self, a, b):
        '''d=2, c=0的多项式核函数'''
        return (a@b.transpose(-1,-2)).transpose(-1,-2).pow(2)

3. 训练

# 损失函数
criterion_cls = nn.CrossEntropyLoss()
criterion_div = DistillKL(opt.kd_T)
criterion_kd = NSTLoss()

for idx, data in enumerate(train_loader):
	# ===================forward=====================
	loss_cls = criterion_cls(logit_s, target)
	loss_div = criterion_div(logit_s, logit_t)
	        
    g_s = feat_s[1:-1]
    g_t = feat_t[1:-1]
    loss_group = criterion_kd(g_s, g_t)
    loss_kd = sum(loss_group)
	
	loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd
    # ===================backward=====================
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # ===================meters=====================
    batch_time.update(time.time() - end)
    end = time.time()

其中的feat_s是中间特征层,例如对于resnet8

if preact:
    return [f0, f1_pre, f2_pre, f3_pre, f4], x
else:
    return [f0, f1, f2, f3, f4], x

源代码

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值