paper:VanillaKD: Revisit the Power of Vanilla Knowledge Distillation from Small Scale to Large Scale
official implementation: GitHub - Hao840/vanillaKD: PyTorch code and checkpoints release for VanillaKD: https://arxiv.org/abs/2305.15781
前言
现有的大多数蒸馏方法大都在小规模数据集如CIFAR以及小尺度的模型如Res34-Res18上进行验证的,然而下游任务大多需要backbone在大规模的数据集上(如ImageNet)进行预训练从而得到sota的表现,只在小数据集上探索蒸馏方法在实际应用中可能无法提供一个全面的理解。考虑到大规模benchmarks的可用性和大模型的容量,我们无法确定之前的方法在更复杂的方案中是否仍然有效,具体包括更强的训练方案、不同的模型容量以及更大的数据规模。
本文对这个问题进行了深入的研究,并深入研究了决定蒸馏性能的关键因素。本文指出了当前知识蒸馏方法中的小数据陷阱(small data pitfall):当在小规模数据集如CIFAR-100(5000张训练图像)上进行评估时,在这些数据集上精心设计的KD方法很容易超过原始的KD。然而,当在大规模数据集上进行评估时,如ImageNet-1K(一百万张训练图像),vanilla KD取得了与其他方法相同甚至更好的结果。
为了解决这个问题,作者首先通过在小规模数据集上训练更长的迭代次数来补偿有限的数据。尽管迭代次数更多,精心设计的KD方法仍然大大优于vanilla KD。由此可见,大规模数据集是vanilla KD实现其最优效果所必需的。
作者进一步研究了决定KD性能的关键因素,并仔细研究了两个关键因素,即训练策略和模型容量。对于不同的训练策略,作者有以下观察: (i)通过评估在小数据集上不同训练策略下都表现良好的精心设计的方法在大型数据集如ImageNet上的表现,通过更强的数据增强和更长的迭代次数,vanilla KD和其他精心设计的KD方法之间的差距逐渐减少。(ii)通过实验还表明,logits-based方法在泛化性方面优于hint-based方法。
对于模型容量,作者比较了不同尺度的teacher-student pairs,例如Res34-Res18和Res152-Res50。结果表明,vanilla KD最终达到了和精心设计的方法相当的性能,这表明模型容量对蒸馏的影响实际上很小。
本文的创新点
- 本文指出了现有的各种精心设计的蒸馏方法存在的小数据陷阱问题,从而导致vanilla KD的表现被低估了。当在更大的数据集上以及使用更强的训练策略时,各类精心设计的蒸馏方法可能变成次优的了。
- 改进backbone在ImageNet上的性能,可以显著提升下游任务如目标检测、实例分割的性能。
小数据陷阱:vanilla KD在小数据集上的表现受限
作者首先研究了使用小模型和小数据集对不同KD方法的影响。为了提供一个全面的分析,作者比较了vanilla KD与两种最先进的logits-based KD方法,即DKD和DIST。这里的重点是确定vanilla KD较差的表现是否可以归因于小学生模型或小规模数据集。
lmpact of limited model capacity
作者首先用两组常用的师生模型进行实验:Res34-Res18和Res50-MobileNetV2,结果如表1a所示
之前的训练策略是采用SGD优化器,训练90个epoch,对应表中的"Original"。现在采用一个更强的训练策略,用AdamW优化器训练300个epoch,对应表中的"Improve"。从结果可以看出,在原始的训练策略下,DKD和DIST相比vanilla KD具有明显的优势,当采用更强的训练策略时,三者的表现都有所提升,但值得注意的是两者和vanilla KD之间的差距缩小了,这表明vanilla KD的性能不足可以归因于不充分的训练,而不是小的模型容量。
Impact of small dataset scale
为了研究小规模数据集对vanilla KD性能的影响,作者在CIFAR-100上进行了实验,同样采用了更强的训练策略,epoch从240延长到2400并且引入了更强的数据增强,如表1b所示。尽管三者的性能都有所增加,但vanilla KD和另两者之间仍然存在明显的差距,这表明vanilla KD的性能不足不仅仅归因于训练的不足,还有小规模数据的原因。
Evaluation of the power of vanilla KD on large-scale dataset
Experimental setup
Datasets. ImageNet-1K
Models. 主要采用Res50作为学生模型,BEiTv2-L作为教师模型,因为当输入为224x224时,后者是目前开源精度最高的模型。此外,还采用了ViT和ConvNeXtV2作为学生模型,ResNet和ConvNeXt作为教师模型。
Training Strategy. 主要采用了两者更复杂的训练策略,如表2所示,其中A1略强于A2。
Baseline distillation methods. 为了进行充分的对比,采用了一些最近新提出的蒸馏方法作为baseline,包括logits-based方法vanilla KD、DKD、DIST和hint-based方法CC、RKD、CRD、ReviewKD
3.2 Logits-based methods consistently outperform hint-based methods
本节对logits-based方法和hint-based方法进行对比分析,采用Res50作为学生模型,Res152和BEiTv2-L作为教师模型,学生模型分别在300 epoch和600 epoch两种配置下蒸馏。结果如表3所示。为了更深入的理解学生模型在ImageNet-1K数据集之外的泛化性,还在ImageNet-Real和ImageNet-V2 matched frequency两个数据集上进行了评估。
在使用A2策略训练了300个epoch后,所有hint-based方法的效果都不如logits-based方法,尽管使用了更强的训练策略A1和更长的迭代次数后,两者之间的差异仍然十分明显。此外,hint-based方法需要更长的训练时间,表明了它在有效性和效率两方面的局限性。
Discussion. 实验结果还表明logits-based方法在泛化性方面也优于hint-based方法,作者推测这种差异是由于教师和学生模型在应对复杂分布时能力的差异导致的。hint-based方法强行让学生模型模拟教师模型的中间特征表示阻碍了其获得满意的结果。此外,当遇到异构的教师-学生架构时,由于表示能力的差异,hint-based方法可能会遇到困难,这阻碍了特征对齐的过程。
3.3 Vanilla KD preserves strength with increasing teacher capacity
表3还展示了使用异构的BEiTv2-L作为教师模型另外两种logits-based方法和vanilla KD之间的比较结果。该教师模型在开源模型的ImageNet-1K上达到了最高的精度。
在所有评估设置中,三种logits-based方法都展现了相似的性能。当Res152作为教师蒸馏了300个epoch后,vanilla KD在ImageNet-1K验证集上得到了80.55%的top-1精度,仅比表现最好的学生模型DIST差0.06%。当训练延长到600个epoch后,vanilla KD蒸馏的学生模型达到了81.33%的精度。当教师模型改为BEiT-L,vanilla KD达到了80.89%的top-1精度,超过了DKD和DIST。即使在有distribution shift的情况下,vanilla KD在ImageNet-Real和ImageNet-V2 matched frequency数据集上也展现出了和其他方法相当的性能表现。
这些结果表明了vanilla KD在以前的研究中被低估了,因为它们是在小规模数据集上设计和评估的。当接受强大的同构和异构教师模型的蒸馏时,vanilla KD取得了和其他sota方法相当的表现,同时仍然保持了其简单性。