前言
Transformer 作为自然语言处理领域的主流方法,近年被越来越多研究应用于机器视觉任务且能实现SOTA 结果。ViT(Vision Transformer)是Transformer在计算机视觉领域成功的应用,也是目前图片分类效果最好的模型,超越了曾经最好的卷积神经网络(CNN)。高性能的ViT可以在JFT-300M这样的巨大数据集上达到理想的分类结果,但当数据集不够大时,其效果是不如CNN的,限制了其应用。2021年Facebook AI与索邦大学联合开发了一项新技术:Data-efficient image Transformers (DeiT),是一种新的高效图像分类算法,用Transformer替代卷积,只需要较少的数据和算力就可以在ImageNet数据集上达到与顶尖的卷积网络(CNN)媲美的图像分类结果(GitHub源码见[4],本文中一,三,四部分所插入图片引用自参考资料[1])。
一、为什么要推出DeiT模型
对于普通开发者,要想利用ViT获得理想的泛化能力,需要庞大的计算资源和数以亿计的图像进行预训练,成本昂贵,而Facebook的DeiT模型(8600万参数)仅用3天时间在一台 8-GPU 的服务器上进行训练,即可在ImageNet数据集上取得比ViT性能更为优秀的效果。在此基础上,加入蒸馏(Distillation)操作,能使性能进一步提升。如下图所示,横轴代表ImageNet数据集上的吞吐量(即throughout, 一秒内可以处理的最大输入实例数),⚗表示加上了特定蒸馏操作的DeiT模型。
二、相关知识
1. 知识蒸馏的作用机制
1. 原理
知识蒸馏是在深度学习,模型压缩领域非常重要的一个方法。其本质是从已训练完善的教师模型,“蒸馏”出“知识”用于学生模型的训练,希望轻量级的学生模型能够达到和教师模型相同的表现。
对于一个模型,我们关注其经过调试能够正确表达输入到输出的映射关系,这种映射关系通常是可以量化为概率的。Hinton等人认为对于分类任务,需要蒸馏出来的知识,就存在于大模型输出错误预测的相对概率。
通常在分类目标中的hard target是0/1编码的,通过熵看携带的信息量比连续的概率输出要少。比如在进行图片分类时,一辆汽车只有很小的概率被误认为是垃圾车,但这种错误的可能性仍然比把它误认为胡萝卜高出许多倍,这种类别间的概率分布对于学生模型的进一步学习很有用,它定义了一种丰富的数据相似结构。将复杂大模型的泛化能力转化到小模型上的一个明显方法就是利用大模型产生的类别概率(一般为softmax输出)作为训练小模型的软目标(soft target),软目标能够提供比hard target更多的信息。
2. 流程
1)带温度的softmax函数通过这个带温度参数的softmax函数,类别间概率分布中存在的相似性信息通过升温被放大,这样才能对损失函数产生更大的影响。也即汽车被误认为垃圾车的概率在升温之后其相对值会变大。
2)损失函数 最终损失函数是软目标函数和硬目标函数的结合:知识蒸馏的过程,简单而言,第一,利用大规模数据训练一个教师网络;第二,利用数据和教师网络指导训练一个学生网络。算法的整体流程如图所示:
2. 关于ViT
受NLP领域中Transformer成功应用的启发,ViT(Vision Transformer)尝试将标准的Transformer结构直接应用于图像,并对整个图像分类流程进行最少的修改。ViT会将整幅图像拆分成小图像块,然后把这些小图像块的线性嵌入序列作为Transformer的输入送入网络,之后使用监督学习的方式进行图像分类的训练。ViT算法结构示意图如下图所示,具体地,先将三维图像展平分块,并对图像块进行embedding得到想要的维度,随后class token与input token合并为向量一起输入Transformer结构中进行特征提取,最后进入MLP层,class token最后的输出结果用来预测类别(下图引用自参考资料[3])。ViT在不同规模数据集上的实验表现为:
在ImageNet等小数据集上,同等规模的ViT训练精度比不上ResNet;
随着训练数据集规模的增加,ViT的性能将开始攀升,并最终超过ResNet。
三、DeiT模型介绍
ViT在大数据集(JFT-300M)上的表现可以达到或超越当前的SOTA水平,但是在小数据集上表现不理想,限制了其应用范围。DeiT的网络结构跟ViT基本一致,主要的区别在于添加了一个蒸馏token以及训练策略不同。如上图,在ViT架构基础上引入的蒸馏token,参与了整体信息的交互过程(在slef-attention layer中与class token和patch token不断交互)。蒸馏token的地位与class token相同,唯一的区别在于,class token的目标是跟真实的label一致,而蒸馏token是要跟teacher模型预测的label一致。蒸馏分为两种,一种是软蒸馏(soft distillation),一种是硬蒸馏(hard-label distillation)。软蒸馏 以交叉熵为分类损失,KL散度为蒸馏损失,用教师网络的softmax输出为标签,也即是常规蒸馏方式:硬蒸馏 以交叉熵为蒸馏损失,以教师网络的硬输出为标签:蒸馏token 见上图模型结构,在输入embedding处添加一个蒸馏token,蒸馏token的使用与class token类似,它通过slef-attention与其他embedding交互。通过训练迭代后,二者的余弦相似性会逐渐增大(但始终小于1),这也表明蒸馏embedding允许我们从teacher的输出中学习,同时与class embedding保持互补。
联合分类器 经过测试发现,class token和蒸馏token是朝着不同的方向收敛的,对各个layer的这两个token计算余弦相似度,平均值只有0.06,但是其余弦相似度随着网络传输会越来越大,在最后一层达到0.93。那么在测试时我们将分别拥有两个token的输出向量,考虑把二者的softmax输出相加,进行预测。
四、实验
DeiT相比于ViT,在结构上添加了蒸馏token。为了便于比较,构建了与ViT类似的三种不同规模的对照组DeiT,其详细参数如下表所示:利用CNN做教师网络 蒸馏出的学生模型在准确率和吞吐量之间的权衡胜过其教师网络,有趣的是,卷积网络作为教师网络的结果要比用Transformer作为教师网络的结果更佳。不同蒸馏策略的比较 从实验结果得出,硬蒸馏显著优于软蒸馏,且class+distillation联合token的效果优于单一token:epoch的个数 增加epoch的数量会显著提升蒸馏训练效果,但是一定数量后也会趋于饱和:下游任务的迁移 已知DeiT在ImageNet上表现良好,为了衡量它的泛化能力,在不同数据集上进行迁移学习。下表将DeiT迁移学习结果与ViT和EfficientNet进行了比较,DeiT与CNN模型的表现相当,这与之前在ImageNet-1k上得到的结论一致:原论文中还包含了一些关于token评估、权衡速度与精度、数据增强、训练策略、fine-tune分辨率等的讨论,不再赘述。
五、总结
我们介绍了一种用于图像transformer的模型,由于改进了训练,特别是引入蒸馏策略,因此不需要非常大量的数据进行训练。这对我们项目当前的研究启发主要有:
蒸馏模型在权衡精度和吞吐量之后的优越性甚至高于teacher模型,说明了蒸馏的有效性;
在Transformer中硬蒸馏的性能是优于软蒸馏的;
对于跨模型的蒸馏,teacher模型的选择很关键,会影响性能的提升;
DeiT模型在小数据集上通过蒸馏token的方式提升精度的思路值得参考,使得实验免于巨量的数据集和训练资源。
对于DeiT,我们的训练依赖于现有的数据增强和卷积网络的正则化策略,除了蒸馏token外,没有引入任何重要的架构更改,因此,期待对图像Transformer的进一步研究。
参考资料
[1] Touvron H, Cord M, Douze M, et al. Training data-efficient image transformers & distillation through attention[C]//International Conference on Machine Learning. PMLR, 2021: 10347-10357.
[2] Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network[J]. arXiv preprint arXiv:1503.02531, 2015, 2(7).
[3] Dosovitskiy A, Beyer L, Kolesnikov A, et al. An image is worth 16x16 words: Transformers for image recognition at scale[J]. arXiv preprint arXiv:2010.11929, 2020.
[4] https://github.com/facebookresearch/deit