NeurIPS 2022 | 知识蒸馏中如何让“大教师网络”也教得好?

35e1a9fe9ea8525468457fc259b787e9.gif

©作者 | 李新春

单位 | 南京大学

研究方向 | 知识蒸馏

本文介绍一篇发表在机器学习顶会 NeurIPS 2022 (CCF-A 类会议)的论文《Asymmetric Temperature Scaling Makes Larger Networks Teach Well Again》。该工作的研究内容为知识蒸馏(Knowledge Distillation),是与华为诺亚联合实验室共同研究产出的一篇工作。本文分为以下几个部分对该工作进行介绍:

  • 文章链接

  • 代码链接

  • 研究背景

  • 提出的方法

  • 实验效果

  • 投稿历程


248a375194f70967371056d35359abfd.png

论文题目:

Asymmetric Temperature Scaling Makes Larger Networks Teach Well Again

文章来源:

NeurIPS 2022

论文链接:

https://arxiv.org/abs/2210.04427

代码链接:

https://github.com/lxcnju/ATS-LargeKD

https://gitee.com/lxcnju/ats-mindspore

作者主页:

http://www.lamda.nju.edu.cn/lixc/

8ab378fefd3e0a942e2829663ca00bb5.png


研究背景

知识蒸馏(Knowledge Distillation)可以将大(强)模型的能力传递给轻量(弱)模型,其基本形式如下:

2ddb8c2323912bb0188e2bb927e76d5c.png

▲ 知识蒸馏示意图

其基本步骤为:1)在训练集上训练一个大教师网络,或者拿现有的当作教师网络;2)使用图示的损失去指导学生网络进行训练。损失包括两部分:正常分类的损失和知识蒸馏损失。前者是 hard-label,后者是 soft-label。引入后者的目的是因为学生直接学习 hard-label 太困难了,因此期望学生能够模仿教师的 soft 输出,从而把握类别之间的相似度,从而更好地学习。

值得注意的是:知识蒸馏损失里面的温度系数 Temperature 很重要!如果 很小,那么教师的输出结果像 hard-label,导致和正常分类损失相比没有什么额外的信息;如果 很大,那么教师的输出结果像 uniform-label,类别之间的差异性就没有了,仅仅起到了一个 label smoothing 的作用。

5dcb0d2b00722645d57998a658f853a4.png

▲ 知识蒸馏中温度系数的作用

普遍的认知是越好的教师教学生教地越好。然而实际上,2019 年有学者 [Jang Hyun Cho, 2019] 指出:大神经网络不一定教地好!

引用下面的一个示意图(来自 [Seyed Iman Mirzadeh, 2020]),随着 teacher size 逐渐变大,教师的准确率越来越高(红色的 teacher accuracy),但是其教的学生的准确率先变高再变低(蓝色的 student accuracy)。现有的工作都将这个奇怪的现象归因于”大教师网络“和”小学生网络“之间的容量差异(capacity gap),但是没有形象地指出这种差异为何出现。

436b2f9204039eec5f8f32e4772ecde1.png

▲ 大神经网络不一定教得好

因此,本文的研究内容就是:为什么大神经网络不一定教地好,有没有什么简单的办法让大神经网络教地好?

7066906ef4cdfb05694eeb4b084e4672.png

▲ 文章的研究内容

cbb9d39180bfa576464526674cce6a08.png


提出的方法

本文最直接的猜测起源于下面的式子:

602707c570aebf5533985174943bd88c.png

▲ 大教师网络和小教师网络在教同一个学生网络的区别

也就是说,在遍历所有可能温度系数的情况下,相比较于大教师网络,小教师网络更容易给出质量更好的指导信息,即

首先,文章通过一些观察实验发现:大教师网络更容易给出置信度较高的预测,包括两个方面。其一,正确类别的 logit 可能更大;其二,错误类别 logits 之间差异更小。本文称神经网络最后一层给出的类别预测得分称为 logits。

b7b68540ca3343428584d9a1d531607c.png

▲ 大神经网络和小神经网络的输出 logits 的分布

具体地,在 CIFAR-100 和 CIFAR-10 上训练 ResNet14/44/110 和 WRN28-1/4/8,统计神经网络输出的 logits 的如下指标:

  • 每个样本正确类别的 logit,记作  

  • 每个样本错误类别 logits 间的方差,记作  

可以看出,在 CIFAR-100 上,ResNet110 很明显给出了更大的 ,在 CIFAR-10 上 WideResNet28-8 给出了更小的 。举例而言,给定五个类别,第一个为正确类别,小教师网络和大教师网络给出的 logits 大概如下:

6711ea442215c3aaec7dff732b2bef26.png

▲ 大教师和小教师给出的logits示例

这就是最基本的现象,也是整个工作的启发点:大神经网络更为置信,给出的 target logit 更大,或者 wrong logits 差异更小!

那么我们不妨设想两个极端:

  • 如果 target logit 非常大,那么无论用什么温度系数对教师的输出进行 softmax,最后得到的 都为 one-hot 形式;

  • 如果 wrong logits 之间差异很小,就假设都一样,那么无论用什么温度系数对教师的输出进行 softmax,最后得到的 在错误类别之间都无法提供差异化信息。

也就是说,大教师网络的高置信度导致:无论在什么样子的温度系数下,其给出的指导信息(即:)都很难具有足够有效的信息!这里足够有效的信息如何定义呢?本文将其定义为,错误类别之间的概率值的方差!

df7ae6dbbc2751d9832c446a707195c8.png

▲ 根据现象得到的猜测

因此本文的猜测为:大神经网络不能教地好的原因是无论使用怎样的温度系数,都难以使得错误类别概率“错落有致”。

为了从理论上去推导验证,本文将知识蒸馏分为三个部分:

eb908e1fecd950692af926e8a1d4205d.png

▲ 知识蒸馏分解

分别包括:1)Correct Guidance,类似于 hard-label 的 one-hot 标签;2)Smooth Regularization,错误类别的平均概率值,类似于 label smoothing;3)Class Discriminability,错误类别之间的差异,可以用方差来度量,错误类别差异越大,教师提供的指导信息越多!

接下来是理论分析,先定义一些符号和公式:

ccb2994dc29b09e941b70614a575b33d.png

▲ 一些基本的符号

理论分析:

982a9b5a2cf8a7fcc72e3b4d3464bbf8.png

事实上:随着 不断增大,得到的 的熵越来越大,即越来越均匀。本研究证明了:随着 不断增大,得到的 元素之间的方差也会越来越小。

7c0c0dbb0b6fc24057dc5b2b31c0a16f.png

在正确类别 logit 最大情况下:随着 不断增大,错误类别概率的均值 会逐渐增大。

最重要的等式为:

325204aa8aa87b3bfac6de145b4b4548.png

其中 DA、IV、DV 分别解释如下:

  • Inherent Variance:错误类别 logits 经过 softmax 之后得到的类别概率分布的方差;

  • Derived Average:所有类别 logits 经过 softmax 之后得到的错误类别概率的平均值;

  • Derived Variance:所有类别 logits 经过 softmax 之后得到的错误类别概率的方差。

针对某一个样本的计算示意图如下(SF 代表 Softmax):

4709cac9eaf7891cfdd6480da7b0d830.png

▲ DA、IV、DV关系示意图

利用该公式解释为什么大教师网络教不好:

8f35d5da0614acae8f387fc75184a935.png

翻译为中文为:

  • 大教师网络给的正确类别 logit 的值很大,导致 DA 变小;

  • 大教室网络给的错误类别 logit 的差异很小,导致 IV 变小。

最终都会导致 DV 变小,即:大教师网络的 DV 很小,传统温度缩放下很难让错误类别的概率“错落有致”。

提出的方法为 Asymmetric Temperature Scaling(ATS),针对正确/错误类别施加较大/较小的温度系数:

  • 大教师网络给的正确类别 logit 的值很大,用较大的 增大 DA;

  • 大教室网络给的错误类别 logit 的差异很小,用较小的 增大 IV。

结论:ATS 可以使得大教师网络的 DV 变大让错误类别的概率“错落有致”。


0883b4eb4f1de64decfcd65394aaa59e.png


实验结果

实验设置和结果就不详细介绍了,有兴趣的可以看文章。下面就简单贴一下结果:

40ce3a373a227ddb4138c2e5b252b01b.png

b4e91603c5df14b398aa64127446b84a.png

fa4672dc6aacfb2b6f3ba1105079ee68.png

c2191dc868d4f2669ba4126d26e6864a.png

4931af8df622b6f3c98300dfdaaf6c62.png

2d6ef77fc229bf8a5e870324c322532e.png


投稿历程

到此,本文的基本方法都介绍完了,是一个非常简单的改进。研究设计的过程中也充满了乐趣,主要包括三个过程:

  1. 发现大神级网络和小神经网络输出的结果具有一些差异,兴奋值 ++;

  2. 发现可以将知识蒸馏的损失分解为三部分,特别是 class discriminability 的定义很有意思,兴奋值 ++;

  3. 发现可以用公式解释大教师神经网络的 DV 很小,兴奋值 ++;

  4. 发现可以提出一个非常简单的 ATS 来使得大教师教地更好,兴奋值 ++。

该工作完成于 2021.1 月份左右,在新年前几天完成的,满怀期待投稿了 ICML 2022。很不幸的是被拒了,个人感觉是在边缘,因为审稿人给的意见都没有特别严重的,主要是一些行文思路和概念没有解释清楚。

于是完善了之后转投了 NeurIPS,得分为 2 (Strong Reject),5(Borderline Accept),6 (Weak Accept)。看到审稿意见本想放弃,但仔细一看给 2 分的貌似只是针对我们公式符号的不合理性进行了攻击,感觉还是有希望的。于是修改了符号,提交了 rebuttal revision。审稿人然后就将分数改为 6。最终得分为 666。

更多阅读

81d075b376555c567b281e6466295779.png

9dbbcc7820b62be9c3795f3b5ce30b12.png

687fc1a6731e2bdc91ba99a3a2e61c40.png

c735300b283da46a7e299d46120bdea5.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

e6b8c253565fbe73245f3eedbe6a7e56.png

△长按添加PaperWeekly小编

🔍

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

b95364091da9e369947b109866013fae.jpeg

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值