FitNets: Hints for Thin Deep Nets 原理与代码解析

paper:FitNets: Hints for Thin Deep Nets

code:https://github.com/megvii-research/mdistiller/blob/master/configs/cifar100/fitnet.yaml

背景

本篇文章之前的卷积神经网络压缩相关的工作都专注于将一个教师网络或多个集成网络压缩成相似宽度和深度的网络或是更浅更宽的网络,没有很好的利用到深度的优点。网络的深度是表示学习的一个基本方面,它能促进特征的重用,在更深的网络层上有更抽象的表示。但网络越深训练起来也越困难,因为深度网络是由连续的非线性组成的,因此产生高度非凸和非线性的函数。

本文的创新点

本文旨在通过利用深度来解决网络压缩的问题,提出了一种新的方法来训练更窄以及更深的网络,称为FitNets。该方法受最近提出的知识蒸馏(KD)的启发,作者引入教师网络隐藏层的intermediate-level hints来引导学生网络的训练,即希望学生网络(FitNet)学习教师网络的中间层表示。

方法介绍

首先介绍知识蒸馏的原理,教师网络 \(T\) 的softmax输出 \(P_{T}=softmax(a_{T})\),其中 \(a_{T}\) 是教师网络pre-softamx的激活向量,\(S\) 表示一个学生网络,网络参数为 \(\mathbf{W_{S}} \),输出概率 \(P_{S}=softmax(a_{S})\),学生网络同时在教师网络的输出 \(P_{T}\) 和标签 \(y_{true}\) 的监督下进行训练,为了提供更多的信息还引入了温度 \(\tau >1\),如下

学生网络的优化损失函数如下

其中 \(\mathcal{H} \) 表示交叉熵损失,\(\lambda\) 是平衡两个交叉熵损失的可调权重系数。

为了帮助比教师网络更深的学生网络FitNets的训练,作者引入了来自教师网络的 hints。hint是教师隐藏层的输出用来引导学生网络的学习过程。同样的,选择学生网络的一个隐藏层称为 guided layer,来学习教师网络的hint layer。注意hint是正则化的一种形式,因此需要仔细地选择hint和guided层学生网络才不会被过度正则化。hint层选的越深,网络的灵活性就越小,学生网络也更容易被过度正则化。因此文中hint和guided层作者都选择了对应网络的中间层。

因为教师网络通常比学生网络更宽,因此对应的hint层和guided层的spatial size可能不一样,作者引入了一个regressor,来使guided层的大小与hint层一致。然后通过优化如下的损失函数来训练学生网络从第一层到guided层以及regressor的参数

其中 \(u_{h}\) 和 \(v_{g}\) 分别表示教师和学生网络从第一层到hint和guided层的网络结构,\(\mathbf{W_{Hint}} \) 和 \(\mathbf{W_{Guided}} \) 分别表示对应的网路参数,\(r\) 和 \(\mathbf{W_{r}} \) 表示guided层后的regressor和对应的参数。

在hint层和guided层都是卷积层输出的特征图的情况下,使用全连接层作为regressor会显著增加参数量和内存消耗,文中作者选用了一个卷积层作为regressor,假设hint层的大小为 \(N_{h,1}\times N_{h,2}\),guided层的大小为 \(N_{g,1}\times N_{g,2}\),则regressor卷积核大小 \(k_{1}\times k_{2}\) 满足 \(N_{g,i}-k_{i}+1=N_{h,i}\),\(i\in\left \{ 1,2 \right \} \)。

FitNet Stage-wise Traning

FitNet采用阶段式的训练方式,如下图所示

给定一个训练好的教师模型和一个随机初始化的学生模型FitNet,如图1(a)所示。在学生网络guided层后添加一个regressor层,按式(3)训练学生网络第一层到guided层和regressor,如图1(b)所示。然后在第一阶段预训练参数的基础上,按式(2)训练整个学生网络,如图1(c)所示。

实现代码

import torch
import torch.nn as nn
import torch.nn.functional as F

from ._base import Distiller


class ConvReg(nn.Module):
    """Convolutional regression"""

    def __init__(self, s_shape, t_shape, use_relu=True):
        super(ConvReg, self).__init__()
        self.use_relu = use_relu
        s_N, s_C, s_H, s_W = s_shape
        t_N, t_C, t_H, t_W = t_shape
        if s_H == 2 * t_H:
            self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)
        elif s_H * 2 == t_H:
            self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1)
        elif s_H >= t_H:
            self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1 + s_H - t_H, 1 + s_W - t_W))
        else:
            raise NotImplemented("student size {}, teacher size {}".format(s_H, t_H))
        self.bn = nn.BatchNorm2d(t_C)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.use_relu:
            return self.relu(self.bn(x))
        else:
            return self.bn(x)


def get_feat_shapes(student, teacher, input_size):
    data = torch.randn(1, 3, *input_size)
    with torch.no_grad():
        _, feat_s = student(data)
        _, feat_t = teacher(data)
    feat_s_shapes = [f.shape for f in feat_s["feats"]]
    feat_t_shapes = [f.shape for f in feat_t["feats"]]
    return feat_s_shapes, feat_t_shapes


class FitNet(Distiller):
    """FitNets: Hints for Thin Deep Nets"""

    def __init__(self, student, teacher, cfg):
        super(FitNet, self).__init__(student, teacher)
        self.ce_loss_weight = cfg.FITNET.LOSS.CE_WEIGHT
        self.feat_loss_weight = cfg.FITNET.LOSS.FEAT_WEIGHT
        self.hint_layer = cfg.FITNET.HINT_LAYER
        feat_s_shapes, feat_t_shapes = get_feat_shapes(
            self.student, self.teacher, cfg.FITNET.INPUT_SIZE
        )
        # [torch.Size([1, 32, 32, 32]), torch.Size([1, 64, 32, 32]), torch.Size([1, 128, 16, 16]), torch.Size([1, 256, 8, 8])]
        # [torch.Size([1, 32, 32, 32]), torch.Size([1, 64, 32, 32]), torch.Size([1, 128, 16, 16]), torch.Size([1, 256, 8, 8])]
        self.conv_reg = ConvReg(
            feat_s_shapes[self.hint_layer], feat_t_shapes[self.hint_layer]
        )

    def get_learnable_parameters(self):
        # for k, v in self.conv_reg.named_parameters():
        #     print(k, v.shape)
        # conv.weight, (128,128,1,1)
        # conv.bias, (128)
        # bn.weight, (128)
        # bn.bias, (128)
        return super().get_learnable_parameters() + list(self.conv_reg.parameters())

    def get_extra_parameters(self):
        num_p = 0
        for p in self.conv_reg.parameters():
            num_p += p.numel()
        return num_p

    def forward_train(self, image, target, **kwargs):
        logits_student, feature_student = self.student(image)
        with torch.no_grad():
            _, feature_teacher = self.teacher(image)

        # losses
        loss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)
        f_s = self.conv_reg(feature_student["feats"][self.hint_layer])
        loss_feat = self.feat_loss_weight * F.mse_loss(
            f_s, feature_teacher["feats"][self.hint_layer]
        )
        losses_dict = {
            "loss_ce": loss_ce,
            "loss_kd": loss_feat,
        }
        return logits_student, losses_dict

  • 5
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00000cj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值