Multi-Exit Network 实现细节

一、动机:
(1)在现实中对图像分类难度不一,采用一个固定的框架对图片进行分类时有时不够灵活。
主要思想就是在一个网络中有多个分类出口(创新点),对于简单图像可以直接从前面某个分类出口得到结果,而难分类的网络可能要到网络后面的某一层才能得到可靠的结果.
(2)动态网络的early-exit可以减少计算量,同时latter-exit也不必重复进行浅层backbone的inference,加快速度 。
主要应用领域:图像分类,语义分割
数据集:MNIST, CIFAR 10,CIFAR 100, ImageNet, COCO, PASCAL VOC

二、exit branch是一个分类器,输出每个类别的概率。
(1)数量:2~n个exit
(2)结构:卷积层、全连接层、池化层+softmax
(3)位置:位于每个block之后。
一些文献指出,在backbone中稍后面的放置一个exit并不一定会提高该分支的整体准确性,可能在更前的位置效果更好。
一般来说,exit位于“自然块”之后,例如concatenation layer,residual connection, dense block之后,它们的性能更好

全连接层作为branch,参考:https://github.com/ArchipLab-LinfengZhang/pytorch-scalable-neural-networks
max pool + avg pool +全连接作为branch,参考:https://github.com/yigitcankaya/Shallow-Deep-Networks
conv+conv+avg pool+全连接作为branch,参考:https://github.com/kalviny/MSDNet-PyTorch

三、多级网络的训练方法:
(1)end-to-end, one-stage, 联合训练所有branch的loss,每个branch的loss有一个超参数权重。问题:该方法对branch位置敏感,一个branch的acc可能受其他branch影响
(2)layer-wise, 分级训练, n-stage,第一次训练模型直到第一个branch的部分,第二次冻结之前的权重,训练模型剩余部分直到第二个branch的部分,依次进行。
(3)classifier-wise, n-stage, 首先训练backbone+final exit,然后冻结backbone,单独训练每个branch
(4)在上述三种基础上,加入知识蒸馏,其中许多文献采用自蒸馏方法,采用最后一级branch作为teacher蒸馏前面的branch
(5)可能存在的其他创新

1、End-to-End训练,并加入自蒸馏:参考 https://github.com/ArchipLab-LinfengZhang/pytorch-scalable-neural-networks
作者采用三个损失
(1)分类损失,交叉熵。
(2)自蒸馏损失,用final exit蒸馏前面的exit,交叉熵。
(3)特征损失,最后一层ResBlock与前面的ResBlock输出的特征损失,计算特征误差的平方和。

for epoch in range(args.epoch):
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0
    for i, data in enumerate(trainloader, 0):
        length = len(trainloader)
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs, feature_loss = net(inputs)  #outputs包含四个branch的输出 [out4, out3, out2, out1], feature_loss是最后的branch输出特征和前面的branch特征的平方差之和

        ensemble = sum(outputs[:-1])/len(outputs)  #ensemble是所有branch的平均输出
        ensemble.detach_()
        ensemble.requires_grad = False

        #   compute loss
        loss = torch.FloatTensor([0.]).to(device)

        #最后一级branch的loss
        loss += criterion(outputs[0], labels)

        #最后一级branch作为teacher
        teacher_output = outputs[0].detach()
        teacher_output.requires_grad = False

        #用final branch蒸馏浅层branch   并计算每个branch的分类loss
        for index in range(1, len(outputs)):
            loss += CrossEntropy(outputs[index], teacher_output) * args.lambda_KD * 9
            loss += criterion(outputs[index], labels) * (1 - args.lambda_KD)

        # 特征差异损失
        if args.lambda_KD != 0:
            loss += feature_loss * 5e-7

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total += float(labels.size(0))
        sum_loss += loss.item()

其中feature loss计算如下:

teacher_feature = out4_feature.detach()
feature_loss = ((teacher_feature - out3_feature)**2 + (teacher_feature - out2_feature)**2 +\
(teacher_feature - out1_feature)**2).sum()

其中蒸馏损失计算如下:

def CrossEntropy(outputs, targets):
    log_softmax_outputs = F.log_softmax(outputs/3.0, dim=1)
    softmax_targets = F.softmax(targets/3.0, dim=1)
    return -(log_softmax_outputs * softmax_targets).sum(dim=1).mean()

2、End-to-End训练,并加入自蒸馏:参考 https://github.com/ArchipLab-LinfengZhang/pytorch-self-distillation-final
作者采用三个损失
(1)分类损失,交叉熵。
(2)自蒸馏损失,用final exit蒸馏前面的exit,交叉熵。
(3)特征损失,最后final branch的特征和前面的branch的特征进行蒸馏,计算特征误差的L2范数

# compute loss
loss = torch.FloatTensor([0.]).to(device)

# final branch的分类损失
loss += criterion(outputs[0], labels)

# final branch的输出logit和特征
teacher_output = outputs[0].detach()
teacher_feature = outputs_feature[0].detach()

# 浅层branch的损失
for index in range(1, len(outputs)):
    # logit蒸馏损失
    loss += CrossEntropy(outputs[index], teacher_output) * args.loss_coefficient
	# 分类损失
    loss += criterion(outputs[index], labels) * (1 - args.loss_coefficient)
	# 特征蒸馏损失  adaptation_layers是多层全连接
    if index != 1:
        loss += torch.dist(net.adaptation_layers[index-1](outputs_feature[index]), teacher_feature) * \
                args.feature_loss_coefficient
        #   the feature distillation loss will not be applied to the shallowest classifier

3、End-to-End训练,加入权重,参考:https://github.com/yigitcankaya/Shallow-Deep-Networks
作者认为预训练backbone+final exit会影响前面exit的性能,因此进行end-to-end training, 并给每个exit加权ti,
ti的值逐渐增大并不超过对应的Ci(15%,30%,45%,60%,75%,90%)

四、inference
多级网络在inference时,在forward计算时加入了置信度判断条件。
一般将当前classifier输出最大概率作为置信度,若超过当前阈值,则退出。
以下代码参考自https://github.com/yigitcankaya/Shallow-Deep-Networks

training时的forward:

    def forward(self, x):
        outputs = []
        fwd = self.init_conv(x)  #init_conv 是开头的结构
        for layer in self.layers:  #layers是backbone的基本结构
            fwd, is_output, output = layer(fwd)
            if is_output:
                outputs.append(output)
        fwd = self.end_layers(fwd)  #end_layers 是最后一级classifier的结构
        outputs.append(fwd)

        return outputs

inference时的forward:

    def early_exit(self, x):
        confidences = []
        outputs = []

        fwd = self.init_conv(x)
        output_id = 0
        for layer in self.layers:
            fwd, is_output, output = layer(fwd)

            if is_output:
                outputs.append(output)
                softmax = nn.functional.softmax(output[0], dim=0)
                
				#当前classifier输出的最大概率作为置信度
                confidence = torch.max(softmax)
                confidences.append(confidence)
            	
				#判断是否要early exit
                if confidence >= self.confidence_threshold:
                    is_early = True
                    return output, output_id, is_early
                
                output_id += is_output

        output = self.end_layers(fwd)
        outputs.append(output)

        softmax = nn.functional.softmax(output[0], dim=0)
        confidence = torch.max(softmax)
        confidences.append(confidence)
        max_confidence_output = np.argmax(confidences)
        is_early = False
        return outputs[max_confidence_output], max_confidence_output, is_early

五、置信度与退出阈值
置信度:大部分文章将输出概率作为置信度。
如果所有branch输出都不满足阈值条件,用所有branch平均输出作为最终输出,或最后一级作为输出。
退出阈值:退出阈值的设置,一些文章采用预先定义阈值的方法。阈值设置一般会宗哥和acc和latency考虑,trade-off
方法一
(1)预先定义阈值,阈值随网络深度增加而递减
(2)将branch输出最大概率作为置信度
参考:https://github.com/ArchipLab-LinfengZhang/pytorch-scalable-neural-networks

def judge(tensor, c):  #一个branch输出的logits  c表示第c个branch
    dic = {0: 0.98, 1: 0.97, 2: 0.98, 3: 0.95}
    maxium = torch.max(tensor)#最大输出概率
    if float(maxium) > dic[c]:
        return True
    else:
        return False

方法二
(1)预先定义阈值,所有branch的阈值相同,采用grid search搜索最佳阈值,
(2)将branch输出最大概率作为置信度
参考:https://github.com/yigitcankaya/Shallow-Deep-Networks


print('Calibrate confidence_thresholds')
confidence_thresholds = [0.1, 0.15, 0.25, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 0.999] # 搜索阈值
model.forward = model.early_exit #修改forward为inference方法
for threshold in confidence_thresholds:
    print(threshold)
    model.confidence_threshold = threshold
    top1_test, top5_test, early_exit_counts, non_conf_exit_counts, total_time = mf.sdn_test_early_exits(sdn_model, one_batch_dataset.test_loader, device)

六、评价指标与SOTA结果
评价指标
(1)每个branch单独计算acc。
(2)所有branch的平均值作为输出计算acc(ensamble)。
(3)比较final branch的acc请添加图片描述
请添加图片描述
请添加图片描述
七、创新点思考
基于多级网络的创新点
一些文献在multi-exit基础上提出了创新:
(1)课程式学习 Curriclum Learning:
由简单到困难来学习课程(在机器学习里就是容易学习的样本和不容易学习的样本),这样容易使模型找到更好的局部最优,同时加快训练的速度。
(2)自蒸馏:采用最后一级branch作为teacher蒸馏前面的branch,用输出特征/概率蒸馏。
(3)多尺度:基于浅层次的分类不能获取图像的高层语义特征的问题,采用多尺度方法,浅层网络即可获取高层特征。
(4)密集连接:基于浅层次的分类器对于后面分类器精度影响的问题,采用密集连接,反向传播的时候,每个分类都可以通过shortcut 对某一层的产生直接的影响,让权重向对每个分类效果更好的方向更新。
(5)梯度平衡:如果end-to-end中损失仅仅采用所有branch的加权求和,由于网络的重叠会产生梯度不平衡的问题。随着branch增加,梯度会逐渐增大使得训练不稳定。
为此,文章提出了gradient equilibrium(GE),梯度平衡。
(6)结构搜索:加入了NAS来寻找最优结构。
(7)inference策略的创新:用熵。
(8)分布式训练。

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值