【论文阅读】NIPS 2023 On student-teacher deviations in distillation: does it pay to disobey?

文章探讨了知识蒸馏中学生模型与教师模型的匹配度问题,揭示了学生模型通过模仿教师并发展自己的信心模式,可能提高泛化能力。研究发现,偏离教师预测的模式有助于正则化和防止过拟合。此外,论文还提出在训练策略中适时调整,如从软标签到硬标签,以优化学生模型的性能。
摘要由CSDN通过智能技术生成


前言

关于蒸馏中的师生偏差:不服从是不是值得的?google的一篇文章,nips2023。主要研究了一个现象,就是学生模型和教师模型的表现不一致的问题。
在这里插入图片描述
发表地点:NIPS 2023;

论文下载链接:https://arxiv.org/pdf/2301.12923.pdf

一、背景

首先针对这个题目中的问题,作者提到在2021年的nips有一篇文章,来自纽约大学。它提出了一个问题,就是知识蒸馏是不是真的有用。这篇文章大概讲一下,就不详细介绍了。

目前有很多知识蒸馏相关的研究,但都主要集中于如何提高学生模型的泛化性,也就是对于新样本的适应能力。但是往往会忽视匹配度这个概念。也就是学生模型的预测和教师模型的预测的匹配程度。相比泛化性,匹配度会更好的反映学生模型蒸馏到了多少教师模型含有的知识。这篇文章对这两种概念做了一个比较详细的解释与区分,也指出获得良好的匹配度对学生模型来说比较困难的。

在这里有一个问题是为什么需要匹配度?之前的研究已经证明了知识蒸馏通常会提高学生模型的泛化能力,但是我们为什么还要关心这个学生模型的匹配度呢?

关于这个主要有三个解释。
首先是学生模型的性能和教师模型的性能往往有比较大的差距,提高匹配度是消除学生和教师性能差异最显而易见的方法。
其次良好的学生模型匹配度可以提高知识蒸馏的可解释性与可信赖性。
最后,将匹配度和泛化性解耦可以帮助更好的理解知识蒸馏是怎么工作的以及如何在各种应用程序中更好的利用它。

接下来就是怎么看一下学生模型的匹配度。这个图2023的文章也提到了。在这里它用了GAN生成的样本逐渐扩大cifar100的蒸馏数据集。
第一个是自蒸馏,蓝色的线代表了匹配度,绿色的线是学生的测试准确率,虚线是教师的准确率,发现随着数据集规模变大,这个匹配度升高的同时,说明学生和老师匹配的是更好的,但是学生的测试准确率反而降低了。这个就比较矛盾。
第二个作者使用3个ResNet-56网络的集成来作为教师模型,使用单个的ResNet-56网络来作为学生模型进行知识蒸馏。发现匹配度又增加了,然后学生的测试准确率略有增加。这个说明如果教师是一个大型模型的话,学生和教师的差距较大,这两个指标又是比较一致的。

二、贡献

现在就存在这样一个问题,偏离与泛化的悖论,就是学生模型从教师模型的预测中偏离,也就是没有那么匹配,反而可能在新的、未见过的数据上表现更好。2023这篇文章作者的目标是解释这一看似矛盾的行为。

他们提出,从教师的概率中的某些系统性偏离是可以改善学生模型的泛化能力的,因为这些偏离起到了正则化的作用。我们知道正则化是一种用来防止过拟合的技术,能够促进模型的简洁性。

在这里插入图片描述

大概讲一下这篇文章的主要贡献:
首先就是夸大的信心:学生和教师模型预测概率不匹配。学生越自信,老师越自信,学生越不自信,老师越不自信。
然后作者还发现自蒸馏(模型既是教师也是学生)可以夸大梯度下降的隐式偏差,也就是沿着顶部数据特征方向收敛得更快

总之,作者观察到,这个学生模型学习教师的时候,它不仅仅是复制教师的行为;它还发展了自己的信心模式,这些模式可能是更极端的。但是有助于学生模型更好地泛化(在新数据上表现更好),因为它就像一种训练形式,一定程度阻止了过拟合。

这些发现可以帮我们理解模型是如何学习的,并可能帮助探索一些更好的训练方法。

然后在这篇文章的设置里,作者提了学生在蒸馏的时候是模仿教师,也就是和教师模型输出求loss,没记入gt的那部分loss。这还有一个one hot的标签,这个标签是01标签,但是不是gt,是教师预测出的one hot

三、方法(a)

接下来简单介绍一下作者观察信心值的方法。这篇文章在训练集和测试集的每个样本(x,y)上分析教师和学生的预测差异,不是像其他工作那样进行整体的集合分析。

在这里插入图片描述

首先就是确定预测类别。这两个都是预测概率向量。第一个te是教师模型对于每个输入x的预测概率向量。第二个st是学生模型对应的预测概率向量。还有一个y,这个是由这个教师概率向量里最高的概率对应的类别确定的。也就是说这个教师和学生都是soft target,然后y是教师预测的类别,是一个类。

接下来选择比较的概率值,最初考虑的是比较这个真实类别y,也就是gt,然后作者提出学生模型是在模仿教师,不是直接预测真实类别,所以作者认为更有意义的比较是针对教师模型的预测类别的概率。

举个简单的例子的话就是正确类别是猫,教师输出给猫的概率是0.4,学生给猫的概率是0.3,原来是比较这两个输出。但是教师给狗的预测概率可能是0.5,预测类别是狗,作者提出的是比较教师和学生对狗的概率。

这样比较的话学生模型的目标是尽量接近教师的预测概率分布,看这个学生能不能准确捕获教师的预测模式。不仅仅是学生预测正确类别的能力。

接下来就是一个logit变换,为了可视化清晰,把概率值从0.1映射到负无穷到正无穷。然后输出的分别是横坐标和纵坐标。

在这里插入图片描述
这个颜色密度图更亮的区域代表数据点更集中。x等于y代表教师和学生的信心相等,点位于虚线下面,代表学生比教师更不自信,点位于线上面,代表学生比教师更自信。发现教师信心不高的时候,学生更不自信。教师已经有高信心的情况,学生也表现出更高的信心,说明蒸馏过程可能不仅仅是简单的知识传递,学生模型还加入了自己的解释,可能有助于提高泛化能力。

然后现在就发现学生会夸大教师的信心,但它在自蒸馏(学生模型和教师模型结构相同)的设置下表现的比教师模型还好。作者提出需要关注蒸馏过程中的另一种夸大行为。也就是刚才提到的第二个关于梯度下降的隐式偏差。

四、方法(b)

梯度下降在模型训练过程中会找到最小化预测错误的模型参数。这篇文章发现梯度下降更快的在数据的主要特征方向上收敛。也就是会优先调整对预测最有影响的特征。

在这里插入图片描述

作者用线性回归情况来进行分析。
x是数据的特征矩阵,在标准的线性回归中,权重在时间t下的更新由这个公式决定。At被定义为一个时间矩阵,随着时间增长,它会让沿着每个特征向量方向的权重收敛。这个I是一个单位矩阵,保证时间是0的时候At也是一个单位矩阵。后面这个指数项可以被分解为特征向量和特征值,其实可以被理解为e-tlamda,代表一个衰减因子,这个特征值lamda越大的话,随着时间这个权重就更新的越快,特征值越小的话更新的越慢。所以说在训练模型的时候,数据中更重要的特征,也就是有较大特征值的方向会得到更快的收敛。

在知识蒸馏里,学生模型的权重不仅收到当前时间矩阵At的影响,还受到教师模型训练结束时刻A矩阵的影响,所以这个学生模型的时间矩阵由两个A矩阵的乘积组成。他们俩的乘积会进一步放大每个A矩阵的影响,这个A矩阵本来就倾向于加快具有大的特征值的特征向量方向的收敛速度,所以学生就更加倾向于主要的特征向量。
在这里插入图片描述

定理:
相比教师模型,学生模型在次要特征方向上依赖得更少。这表明,尽管教师模型已经倾向于更快地沿着主要特征方向收敛,但在蒸馏过程中,学生模型这种趋势更强。作者随后也进行实验在手写数据集上也验证了这种收敛偏差。

🍿

总结一下这两个发现。
信心的夸大:研究展示了学生模型倾向于夸大教师模型的信心水平。这意味着如果教师对某个预测非常有信心,学生模型会展现出更高的信心水平;反过来也是。
特征空间收敛的夸大:在特征空间内,学生模型的收敛行为也被夸大。具体来说,学生模型更倾向于快速沿着数据主要变化的方向(即特征值较大的特征向量方向)收敛。
在这里插入图片描述

🍿

这张图片展示了夸大的偏差,也就是梯度下降沿着数据的主要特征收敛。也让这个学生模型相对于教师的信心水平变高了。同时它也有可能有助于模型在看不见的数据上更好的预测,也就是提高泛化能力。

结论就是学生模仿过头可能带来更好的结果。
在这里插入图片描述

这里还提到一个混杂因素,就是这个虚框。除了夸大的偏差,这种混杂因素也会影响学生模型的泛化。

五、实验

🍿

接下来是验证刚才理论的结果。第一张图的cifar100数据集部分标签有错误。教室模型对错误的one hot标签已经置信度很低了,但是学生模型的更低,我们发现基本完全在x=y这条线的下面。不管是在自蒸馏设置还是跨蒸馏设置里。
在这里插入图片描述

也就是说,观察到的标签是噪声的,教师模型的隐式偏差有助于去噪标签。

图验证了这一点,并发现错误标记的点对应于教师模型较低的置信度。这样的话,学生模型对于错误标记的数据点的拟合程度比教师模型更差。

自蒸馏的ResNet56模型比相同的教师模型提高了3%的性能。先前的工作认为,这是因为教师模型的隐式偏差导致其概率部分去噪,与one hot标签相比更准确。但这不能解释为什么复制教师概率的学生模型——能够超越教师。

作者解释说,就像刚才提到的一样,学生模型并不是简单地复制教师模型对于顶部特征向量的偏差,而是放大了这种偏差。这种放大的偏差提供了增强的去噪效果,让学生能超越教师模型。

🍿

作者又提出了一个问题是什么时候蒸馏会损害泛化能力。如果教室模型在训练数据上没有达到足够的top one准确率的时候,可能会对学生造成损害。也就是就算模型足够大,训练集表现不好的话,甚至可能也会对学生的泛化能力造成负面影响。

在这里插入图片描述

🍿

第一个证据是一个插值和非插值教师的实验。也就是能不能准确预测出训练集中每个数据点的标签。第一个情况是在整个数据集上插值的教室模型,100%的准确率,第二个是只在一半数据集上插值的教师。发现一半插值的甚至会损害学生性能。就是比one hot还要差。
在这里插入图片描述

🍿

第二个就是在非插值型教师的情况下,存在着一种最优的方式,也就是把蒸馏损失和one hot结合。把知识蒸馏逐渐向one hot转换,可以提高泛化能力。也就是教室不完美的时候,直接从软标签到硬标签可以提高准确率。

🍿

第三个证据讨论了当教师模型已经具有完美的top-1训练准确率时,从知识蒸馏切换one hot损失可能是不利的。

在我们的特征空间内,当教师有不完美的top one准确率,也就是非插值教师时,这可能意味着教师没有沿着数据的某些关键的特征方向充分收敛。被蒸馏放大的偏差会进一步限制沿这些方向的收敛,损害泛化。

总结

不完全复制老师的预测可能是有益的:当较小的模型(学生)向较大的模型(老师)学习时,如果学生不完全复制老师的预测可能会更好,尤其是当老师不太自信的时候。这可以帮助防止学生发现老师的错误,并使学生的学习过程更加稳健。

中途改变训练方法可能会有所帮助:如果教师模型没有为学生模型提供足够清晰的指令来学习正确的结果,那么在训练中途切换到更直接的学习方法可能会很有用。比如在训练中期实施 one-hot loss 可能是有利的。这就像老师从提示开始,然后直接给出答案,以确保学生走在正确的道路上。

多层模型:研究如何将所讨论的策略不仅应用于模型的一部分,而且应用于多个层,从而可能进一步改进学习过程。

夸大教学信号:探索如何在简单的设置中使学习信号更强,然后将这个想法应用于更复杂的蒸馏形式。

将这些想法应用到不同的学习场景中:将所学知识应用于系统中不同类型的模型训练中,例如半监督学习(模型使用标记和未标记数据的混合进行学习)对项目进行排名而不是对其进行分类,或者在学习不同抽象级别的表示的模型中。

  • 25
    点赞
  • 38
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值