Knowledge Distillation with the Reused Teacher Classifier论文阅读+代码解析

论文地址点这里

一. 介绍

给定一个具有大量参数的强大教师模型,知识蒸馏(KD)的目标是帮助另一个参数化较少的学生模型获得与较大的教师模型相似的泛化能力。实现这一目标的一种直接方法是在给定相同输入的情况下对齐它们的logits或类预测。由于KD技术概念简单,实用有效,在各种应用中都取得了巨大的成功。
在知识蒸馏中,学生模型和教师模型的差距成为一项关键性的挑战。
在本文中,我们提出了一个简单的知识蒸馏技术,并证明它可以显著弥合教师和学生模型之间的性能差距,而不需要详细的知识表示。我们提出的“SimKD”技术如下图所示。我们认为,教师模型强大的班级预测能力不仅归功于这些表达特征,而且同样重要的是,一个鉴别分类器。基于这一论点,我们在分类器的上一层通过特征对齐来训练学生模型,并直接复制教师分类器进行学生推理。这样,如果我们能将学生的特征与教师模型的特征完美地对齐,他们的表现差距就会消失。也就是说,单独的特征对齐错误解释了学生推理的准确性,这使得我们的知识转移更容易理解。根据我们的实验结果,单个 l 2 l_2 l2损失的特征对齐已经出奇地好。如此简单的损耗使我们不必像以前的工作那样仔细地调优超参数,以平衡多个损耗的影响。
在这里插入图片描述

二. 方法

2.1 标准的知识蒸馏模块

一般来说,当前流行的为图像分类任务设计的深度神经网络可以看作是一个具有多个非线性层的特征编码器(特征提取层)与一个通常包含单个具有softmax激活函数的全连接层的分类器进行组合的结构。这两个组件都将通过反向传播算法进行端到端训练。
给定一个训练样本 x \boldsymbol{x} x以及其对应的标签 y \boldsymbol{y} y,我们可以使用特征提取参数进行特征的计算: f s = F s ( x ; θ s ) ∈ R C s \boldsymbol{f}^s=\mathcal{F}^s\left(\boldsymbol{x} ; \boldsymbol{\theta}^s\right) \in \mathbb{R}^{C_s} fs=Fs(x;θs)RCs。该特性随后通过权重 W s ∈ R K × C s \boldsymbol{W}^s \in \mathbb{R}^{K \times C_s} WsRK×Cs传递给分类器,从而获得输出: g s = W s f s ∈ R K \boldsymbol{g}^s=\boldsymbol{W}^s \boldsymbol{f}^s \in \mathbb{R}^K gs=WsfsRK,之后使用一个激活函数(softmax)计算类的预测损失 p s = σ ( g s / T ) ∈ R K \boldsymbol{p}^s=\sigma\left(\boldsymbol{g}^s / T\right) \in \mathbb{R}^K ps=σ(gs/T)RK,其中计算如下:
p i s = exp ⁡ ( g i s / T ) ∑ j = 1 K exp ⁡ ( g j s / T ) , (1) p_i^s=\frac{\exp \left(g_i^s / T\right)}{\sum_{j=1}^K \exp \left(g_j^s / T\right)}, \tag1 pis=j=1Kexp(gjs/T)exp(gis/T),(1)
这里的 i i i表示为第 i i i个样本。 T T T为一个温度的超参数负责软化输出结果。
传统的知识蒸馏由两部分组成:一个是交叉熵损失,另一个是KL距离的损失:
L K D = L C E ( y , p s ) ⏟ T = 1 + T 2 L K L ( p t , p s ) ⏟ T > 1 . (2) \mathcal{L}_{\mathrm{KD}}=\underbrace{\mathcal{L}_{\mathrm{CE}}\left(\boldsymbol{y}, \boldsymbol{p}^s\right)}_{T=1}+\underbrace{T^2 \mathcal{L}_{\mathrm{KL}}\left(\boldsymbol{p}^t, \boldsymbol{p}^s\right)}_{T>1} . \tag2 LKD=T=1 LCE(y,ps)+T>1 T2LKL(pt,ps).(2)

2.2 简单的知识蒸馏(本文的方法)

在这里插入图片描述
近年来,提出了各种特征蒸馏方法。这些工作主要是从中级师生层对中收集和传输额外的梯度信息,以便更好地训练学生特征编码器(上图(b))。然而,它们的成功在很大程度上依赖于那些特别设计的知识表示,以包含适当的归纳偏差,以及精心选择的超参数来平衡不同损失的影响。两者都是劳动密集型和耗时的。也很难断定某一类型的表征在学生培养中所起的实际作用。
相反,我们提出了一种简单的知识蒸馏技术,称为SimKD,它摆脱了这些严格的要求,同时在大量的实验中仍然获得了最先进的结果。如上图©所示,SimKD的一个关键组成部分是“分类器重用”操作,即我们直接借用预先训练好的教师分类器进行学生推理,而不是训练一个新的分类器。这样就不需要用标签信息来计算交叉熵损失,使得特征对齐损失成为产生梯度的唯一来源。
我们认为教师分类器中包含的鉴别信息很重要,但在KD的文献中很大程度上被忽略了。然后,我们为它的重要作用提供了一个合理的解释。考虑这样一种情况,一个模型被要求处理几个具有不同数据分布的任务,一个基本的做法是冻结或共享一些浅层作为跨不同任务的特征提取器,同时微调最后一层以学习特定于任务的信息。在这种单模型多任务设置中,现有研究认为任务不变信息可以共享,而任务特定信息需要独立识别,通常由最终分类器进行识别。对于在同一数据集上训练具有不同能力的教师和学生模型的KD,类似地,我们可以合理地认为,数据中有一些能力不变的信息很容易通过不同的模型获得,而强大的教师模型可能包含额外的基本的能力特定信息,而简单的学生模型很难获得。此外,我们假设大多数特定于能力的信息包含在深层中,并期望重用这些层,甚至只有最后的分类器将有助于学生的训练。
基于这一假设,我们为学生模型提供了教师分类器进行推理,并强制其提取的特征与下面的这个 l 2 l_2 l2损失函数相匹配:
L SimKD  = ∥ f t − P ( f s ) ∥ 2 2 , (3) \mathcal{L}_{\text {SimKD }}=\left\|\boldsymbol{f}^t-\mathcal{P}\left(\boldsymbol{f}^s\right)\right\|_2^2, \tag3 LSimKD = ftP(fs) 22,(3)
其中 P \mathcal{P} P表示为将学生网络的特征输出与教师网络的特征输出进行对齐的函数。
有些令人惊讶的是,通过这种简单的技术,师生压缩中的性能下降将得到极大的缓解。随着高推理精度,这种单损失公式的简单性为我们的SimKD提供了良好的解释性。注意,来自预训练的教师模型的重用部分被允许合并更多的层,但不限于最终的分类器。通常情况下,重用的层数越多,学生的准确率越高,但会增加推理的负担。

三. 代码解析

代码地址点这里
(本文的教师网络模型需要进行预训练)
在本方法中,教师网络和学生网络直接使用普通的的卷积神经网络即可(如ResNet18),这里没有变化,而本文中我们需要重新利用教师网络的分类器的功能,同时需要保证学生网络和教师网路特征输出大小一致,这里添加了一个新的中间网络层,如下:

class SimKD(nn.Module):
    """CVPR-2022: Knowledge Distillation with the Reused Teacher Classifier"""
    """ s_n为学生模型输入的特征的大小,t_n为教师模型输出的特征大小"""
    def __init__(self, *, s_n, t_n, factor=2): 
        super(SimKD, self).__init__()
       
        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))       

        def conv1x1(in_channels, out_channels, stride=1):
            return nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride, bias=False)
        def conv3x3(in_channels, out_channels, stride=1, groups=1):
            return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False, groups=groups)
        
        # A bottleneck design to reduce extra parameters
        setattr(self, 'transfer', nn.Sequential(
            conv1x1(s_n, t_n//factor),
            nn.BatchNorm2d(t_n//factor),
            nn.ReLU(inplace=True),
            conv3x3(t_n//factor, t_n//factor),
            # depthwise convolution
            #conv3x3(t_n//factor, t_n//factor, groups=t_n//factor),
            nn.BatchNorm2d(t_n//factor),
            nn.ReLU(inplace=True),
            conv1x1(t_n//factor, t_n),
            nn.BatchNorm2d(t_n),
            nn.ReLU(inplace=True),
            ))
        
    def forward(self, feat_s, feat_t, cls_t):
        
        # Spatial Dimension Alignment
        s_H, t_H = feat_s.shape[2], feat_t.shape[2]
        if s_H > t_H:
            source = F.adaptive_avg_pool2d(feat_s, (t_H, t_H))
            target = feat_t
        else:
            source = feat_s
            target = F.adaptive_avg_pool2d(feat_t, (s_H, s_H))
        
        trans_feat_t=target
        
        # Channel Alignment
        trans_feat_s = getattr(self, 'transfer')(source)

        # Prediction via Teacher Classifier
        temp_feat = self.avg_pool(trans_feat_s)
        temp_feat = temp_feat.view(temp_feat.size(0), -1)
        pred_feat_s = cls_t(temp_feat)
        
        return trans_feat_s, trans_feat_t, pred_feat_s

有了这3个部分后,我们依次传入一个module_list(顺序为学生网络,特征转换网络,教师网络),之后就可以开始训练,如下:

def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, opt):
    """one epoch distillation"""
    # set modules as train()
    for module in module_list:
        module.train()
    # set teacher as eval()
    module_list[-1].eval()

    criterion_cls = criterion_list[0]
    criterion_div = criterion_list[1]
    criterion_kd = criterion_list[2]

    model_s = module_list[0]
    model_t = module_list[-1]

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    n_batch = len(train_loader) if opt.dali is None else (train_loader._size + opt.batch_size - 1) // opt.batch_size

    end = time.time()
    for idx, data in enumerate(train_loader):
        if opt.dali is None:
            if opt.distill in ['crd']:
                images, labels, index, contrast_idx = data
            else:
                images, labels = data
        else:
            images, labels = data[0]['data'], data[0]['label'].squeeze().long()
        
        if opt.distill == 'semckd' and images.shape[0] < opt.batch_size:
            continue

        if opt.gpu is not None:
            images = images.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
        if torch.cuda.is_available():
            labels = labels.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
            if opt.distill in ['crd']:
                index = index.cuda()
                contrast_idx = contrast_idx.cuda()

        # ===================forward=====================
        feat_s, logit_s = model_s(images, is_feat=True)
        with torch.no_grad():
            feat_t, logit_t = model_t(images, is_feat=True)
            feat_t = [f.detach() for f in feat_t]

        cls_t = model_t.module.get_feat_modules()[-1] if opt.multiprocessing_distributed else model_t.get_feat_modules()[-1]
        
        # cls + kl div
        loss_cls = criterion_cls(logit_s, labels)
        loss_div = criterion_div(logit_s, logit_t)
        
        # other kd loss
        trans_feat_s, trans_feat_t, pred_feat_s = module_list[1](feat_s[-2], feat_t[-2], cls_t)
        logit_s = pred_feat_s
        loss_kd = criterion_kd(trans_feat_s, trans_feat_t)
            
        loss = opt.cls * loss_cls + opt.div * loss_div + opt.beta * loss_kd
        losses.update(loss.item(), images.size(0))

        # ===================Metrics=====================
        metrics = accuracy(logit_s, labels, topk=(1, 5))
        top1.update(metrics[0].item(), images.size(0))
        top5.update(metrics[1].item(), images.size(0))
        batch_time.update(time.time() - end)

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()        

        # print info
        if idx % opt.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'GPU {3}\t'
                  'Time: {batch_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}\t'
                  'Acc@5 {top5.avg:.3f}'.format(
                epoch, idx, n_batch, opt.gpu, loss=losses, top1=top1, top5=top5,
                batch_time=batch_time))
            sys.stdout.flush()

    return top1.avg, top5.avg, losses.avg

这里我给大家画了个示意图,还是挺清晰的,大家可以自己看看,值得注意的是,在进行学生网络的预测时我们需要使用到学生网络+特征转换网络一起的输出进行预测(也就是pred_feat_s)。
在这里插入图片描述

四. 总结

其实本文很简单,通过重用教师网络的分类器以及对学生网络教师网络之间特征的对齐从而能是的学生能向着更好的方向进行发展。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
数据免费知识蒸馏与软目标传输集合合成是一种通过利用现有数据集来提高深度神经网络的性能的方法。这种方法主要包括两个步骤:知识蒸馏和软目标传输集合合成。 首先,知识蒸馏是指将一个已经训练好的大型模型的知识转移到一个小型模型中。这样做的好处是,小型模型可以通过利用大型模型的知识来提高其性能。知识蒸馏的过程包括将大型模型的输出(一般是概率分布)作为目标分布,然后使用目标分布和小型模型的输出之间的交叉熵作为损失函数进行训练。通过这种方式,小型模型可以学习到大型模型的知识,并提高其性能。 其次,软目标传输集合合成是指通过合成新的目标数据集来进一步提高小型模型的性能。这是通过将已有数据集中的样本与大型模型的输出结合起来产生的。具体而言,对于每个样本,使用大型模型进行预测,并根据预测结果以及训练集中的标签来合成一个新的目标分布。然后,再次使用目标分布和小型模型的输出之间的交叉熵作为损失函数进行训练。通过这种方式,小型模型可以进一步学习到大型模型的知识,并提高其性能。 总之,数据免费知识蒸馏与软目标传输集合合成是一种提高深度神经网络性能的有效方法。通过利用已有的数据集和大型模型的知识,可以帮助小型模型更好地学习并提高其性能。这种方法在许多领域中都有广泛的应用,例如计算机视觉、自然语言处理等。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值