论文见:
https://arxiv.org/abs/2211.12905
MindSpore代码:
https://gitee.com/mindspore/models/tree/master/research/cv/ghostnetv2
01
引言
Transformer已经成为了语言和视觉任务中常用的基础架构之一。然而,由于Transformer结构高计算开销的影响,其在端侧等资源受限设备中的应用依然面临很大的挑战。我们对Transformer结构中的标准化层和注意力机制两个模块的优化策略进行了深入探索,从而构建一个高效的Transformer结构。其中,LayerNorm作为Transformer结构中常用的标准化层,但模型推理时仍需计算数据的统计值,导致了推理的低效。我们提出了渐进式的LayerNorm替换策略,并对标准的BatchNorm进行了改进以更好地取代LayerNorm层。同时,我们采用了一种简单高效的线性注意力模块(Simplified Linear Attention),来获得更强的模型性能。我们将这两种策略的结合简称为SLAB。我们在图像分类、目标检测以及语言任务上都进行了大量的实验,获得了很好的效果。例如,我们的SLAB-Swin-S在ImageNet1k数据集上获得了83.6%的分类精度,相对Flatten-Swin-S在精度提升0.1%的情况下,时延减少了2.4ms。
02
方法
2.1 渐进式重参数化BatchNorm
LN作为Transformer中常用的标准化层结构,由于其在训练和推理两阶段均存在均值和方差的计算,影响了Transformer的执行速度。与之相对,BN仅在训练阶段存在均值和方差的计算,且在推理阶段可与相邻的线性层融合,可以去除标准化层对模型推理速度的影响。但是,在Transformer结构中将LN简单替换为BN训练会导致模型精度下降以及训练崩溃等问题。为解决这个问题,我们对BN进行了优化,并提出了渐进式重参数化BatchNorm策略。
首先,重参数化BatchNorm的定义如下:
式中,𝜂是一个可学习参数。其中,RepBN可以通过调节BN的权值和偏移量,是特定层BN操作被跳过;当𝜂为0时,RepBN等效为纯BN结构。同时,RepBN能重参数化为BN的表现形式,并实现与相邻线性层的融合。
其次,为增强BN在Transformer结构中的训练稳定性,我们引入了渐进式替换策略。其表示形式如下:
式中,𝛾是一个超参数,用于控制LN和RepBN的输出比例。在训练开始阶段,𝛾一般设置为1,此时LN在模型中发挥主导作用;在训练结束阶段,𝛾将衰减至0,此时模型将转变为纯BN组成的结构。在实际应用中,我们采用了简单的线性替换策略,𝛾的值输出如下:
其中,𝑇为训练中包含LN的总训练步数,T𝑐𝑢𝑟为模型当前的训练步数。相对于其他衰减策略,我们发现线性策略更为简单且高效。因此,后续实验中我们均采用了线性衰减的策略。
2.2 简化线性注意力
Attention是Transformer网络中重要的模块之一。为进一步压缩模型计算量,我们引入了线性注意力模块。在该模块中,我们仅使用了硬件亲和的ReLU算子作为相似度函数,并增加了一个深度可分离模块增强局部特征提取。该简单线性注意力模块(simplified linear attention, SLA)形式如下:
式中,DWC表示深度可分离卷积。
03
实验结果
3.1 分类任务
我们在ImageNet1k数据集上进行了实验,实验结果证明在多个backbone上,我们的PRepBN均获得了与LN相当甚至更好的性能。从实验结果看,相当基于LN的模型,PRepBN模型的分类精度有0.1%~1.4%的提升。而基于我们SLAB的模型,能在精度与Flatten Transformer相当的情况下,减少模型的推理的时延。
3.2 检测任务
此外,我们验证了不同backbone在COCO数据集上的效果。从实验结果可以看出,我们的方法实现了与原Backbone模型相当的性能,但拥有更低的模型推理时延。
3.3 语言任务
我们基于Adaptive inputs方法在Wikitext-103数据集上评测了PRepBN在语言任务的能力。同时,我们也将PRepBN应用在了LlaMA-350M模型中,并评测了模型在下游任务的性能。从实验结果可以看出,我们的PRepBN方法在语言任务上也表现出了不错的性能,精度无损的情况下将LLaMA-350M速度从44tokens/s提升到了50.4tokens/s。
04
总结
我们对Transformer结构中的标准化层和注意力机制两个模块的优化策略进行了深入探索,提出了渐进式的LayerNorm替换策略,同时采用一种简单高效的线性注意力模块,来获得更加高效的Transformer模型架构。这个方法在图像分类、目标检测以及语言任务上进行了大量的实验验证,在精度无损的情况下,大幅提升Transformer的推理效率。