【论文阅读】Born-Again Neural Network

系列文章目录

Dataset Shift



【阅读笔记】【AI测试】Born-Again Neural Networks

阅读笔记,非全文翻译


Abstract

  1. 知识蒸馏(Knowlodge Distillation,KD)意在使学生模型学到教师模型的知识,从而有更紧凑的体积,同时不牺牲太多的性能
  2. 文章从新的角度研究了KD方法:不压缩模型,而是对学生模型进行和老师模型一样的参数化训练(parameterized training)
  3. Born-Again Networks(BANs)在CV和NLP上都表现得比教师显著更好
  4. 基于DenseNets的BANs实验在CIFAR10上SOTA3.5%验证误差,CIFAR100上SOTA15.5%验证误差。
  5. 额外的实验,探索了两个distillation objectives:①CWTMDKPP,两种方法都阐述了KD的组成部分,表明了教师输出对预测类和非预测类的影响。

1 Introduction

  1. Born-Again Trees设想:学习一棵与多数预测器性能相近的单树。
  2. KD方法中,虽然直接根据数据进行训练时,学生无法与教师相匹配,但蒸馏过程使学生更接近于与教师的预测能力相匹配
  3. 在将知识从教师传授给能力相同的学生的实验中,意外地发现学生大大超过了他们的老师。
  4. 提出简单的再训练模式:在教师模型收敛后,我们初始化一个学生,并用两个目标来来训练它:①预测正确的标签;②输出的分布与教师尽可能接近。 这学生模型就称之为BANs。LSTM、DenseNet,ResNet的BAN都比它们验证误差小。
  5. KD引入的梯度包含两个项:
  6. 弱老师也能教出强学生

2 相关文献

2.1 知识蒸馏

  1. 神经网络的可解释性或透明性,依然是很模糊的
  2. 与本文类似研究(Yim et al. 2017)表明,把KD应用到两个架构一样的模型上,学生模型训练的更快,且准确率更高。

与Yim研究的关键区别
3.

2.2 Residual and Densely Connected Neural Networks

3 Born-Again Networks

在这里插入图片描述

  1. 泛化误差可以通过改变损失函数来减小,经典方法:正则化来减小模型复杂度

3.1 Sequence of Teaching Selves Born-Again Networks Ensemble

在这里插入图片描述
在这里插入图片描述

3.2 Dark Knowledge Under the Light

KD方法的可行性在于,分类错误的样本的output logits中隐藏着暗知识。softmax层的输出,除了正例之外,负标签也带有大量的信息,比如某些负标签对应的概率远远大于其他负标签。而在传统的训练过程(hard target)中,所有负标签都被统一对待。也就是说,KD的训练方式使得每个样本给Net-S带来的信息量大于传统的训练方式。

3.3 BANs对深度和宽度变化的稳定性

4 实验

4.1 CIFAR-10/100

Baselines
  • 不同heights和growth factors的DenseNet
  • 然后使用了classical ResNet
  • 然后构建了Wide-ResNet and bottleneck-ResNet学生用于实验
BAN without 暗知识
  • CWTM设置,排除了非最大值的logits的影响,相当于hard softmax。用teacher输出的最大值作为学生loss的权重。
  • DKPP设置,除了最大值的logits,其他logits随机排列。
  • 两种设置都改变了输出的协方差,如果发生改进,不能归因于经典的暗知识解释
BAN-Resnet with DenseNet 老师
  • 学生和老师共享第一层和最后一层(【问题】结构还是参数?)
BAN-DenseNet with ResNet 老师
  • 测试了:一个弱的Resnet老师能否教出DenseNet-90-60学生。

4.2 Penn Tree Bank(数据集)

  • 除了应用到CV,还把BAN方法应用到NLP上,在PTB数据集上。
  • 单层LSTM,with 1500 units,65% dropout,40 epochs,使用SGD,batchsize 32,自适应学习率初始1,调节率0.25
  • 一个小型模型,包括:一个卷积层,highway层,一个2层的LSTM(参考了CNN-LSTM),使用SGD,40 epochs,batchsize20,学习率(2,0.5)
  • 两个模型【问题】没懂这里的话

5 结果

  • BAN学生模型几乎提升了所有配置下的教师表现

5.1 CIFAR-10

在这里插入图片描述

5.2 CIFAR-100

BAN-DenseNet 和 BAN-ResNet
  • The improvement of fully removing the label supervision is systematic across modality,
    在这里插入图片描述
  • 疑问:为什么这里BAN(只用了软目标)效果最好,而Hinton说加入hard loss效果更好?
  • 疑问:这里的Ensemble是哪几个的BAN的Ensemble?没提
Sequence of Teaching Selves
  • 几代之后就收敛了
  • BAN-3在DenseNet-80-80上获得了非集成模型、没有用shake-shake正则的SOTA
    在这里插入图片描述
BAN-Ensemble
  • Ens-3-DenseNet-80-120,14.9%误差,150M参数,是有过报告的最低误差的集成模型
  • BAN-3-DenseNet-112-33,16.59%误差,6.3M参数,效果也很好
Effect on non-argmax Logits
  • CWTM弱提升,有的提升了,有的减弱了;DKPP系统性提升,误差都降低了
  • DKPP表明:输出分布的高阶矩(higher order moments,不知道是什么)对置换过程有不变性,依然可以提升泛化能力
  • CWTM表明:完全删除错误信息依然可以改善教师模型,【问题】这里似与minibatch有关?
DenseNet to modified DenseNet students
  • Table 4看出,DenseNet学生,在层数变化的方面很有鲁棒性;0.5*Depth的模型都有提升(16.95%误差)
  • 最大的不稳定性发生在改变压缩率上(compression rate,Compr),compr越小,保存的参数越多,acc损失越小
    在这里插入图片描述
DenseNet Teacher to ResNet Student
  • Table 5中发现,Wide-ResNet、Pre-ResNet的表现,比它们的DenseNet老师(这些老师的每一个stage的输出形状都与Wide-ResNet、Pre-ResNet相同),比传统的ResNet,比它们的Baseline都要好。【问题】这里的Baseline是什么
ResNet Teacher to DenseNet Students
  • Table 3最右列,只用简单的标签(类似于hard prediction)训练,就能得出更好的学生模型

5.3 PTB数据集

  • 没有使用SOTA的LSTM训练方法,也没有使用最新的KD方法,仍然观察到了语言模型在经过BAN处理后,在验证集和测试集上复杂度(perplexity)都降低了。
  • 不像CNN在CIFAR上的表现,作者发现LSTM只在使用BAN+L(使用教师的输出和标签,两种组合起来计算Loss)时有效果。可能是由于,CNN教师在CIFAR10上表现已经接近100%了,但LSTM教师在PTB上远远没有达到最优。

6 讨论

略。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值