EfficientFormer
transformer引入MHSA,有长时依赖而且并行化简单,但高延迟,因为参数量大、token长度平方级计算复杂性增长、非折叠的均一化层、缺少编译器级优化。
这篇文章就是想要把延迟降到和CNN差不多,但不影响准确率。
主要思想
分析
为什么在边缘计算延迟高
1、Patch embedding with large kernel and stride .
对比是DeiT、PollFormer和LeViT-256;许多编译器无法支持大核卷积,而且大核卷积也不能通过现有地算法加速。
优化:用多个3*3卷积代替重叠块嵌入。
2、Consistent feature dimension is important for the choice of token mixer.
对比PoolFormer和LeViT,发现Reshape操作高延迟;LeViT主要用4D向量,送入MHSA时需要reshape成3D,池化时需要4D,因此reshape操作还挺多;
对比DeiT和LeViT,MHSA并没有带来更多的开销;
优化:维度一致的网络结构,但是包含4D特征图和3D MHSA,消除reshape操作。
3、CONV-BN比LN-proj.更快,而且精度下降不多
对比DeiT和PoolFormer,BN操作可以折叠到前面的卷积中获得推理加速,而动态归一需要在inference阶段手机运行统计信息;
4、非线性延迟来自于硬件和编译器
在手机上,HardSwish最慢,GeLU比ReLU慢一点,这个需要根据硬件不同而改变,这次实验中用的GeLU。
网络结构
这个网络结构与MobileNet完全不同,可以看到没有深度分离卷积操作。
- stage将具有相同空间大小特征的几个MetaBlock堆栈;
- patch embedding用步长为2的卷积层代替;
- 4D->3D用的还是reshape;
-
提出了一个超网概念MP,MP里面有MB4D、MB3D和残差I,就是能用MP里面的模块代替目前的网络结构中的模块;
-
本论文规定3D只在最后两个stage用,是因为MHSA对于token长度平方级增长以及MHSA主要提取全局特征,前面部分提取局部特征;
-
对Ci和Ni有不同的选择,为了找到合适的数值,设计了一个简单的搜索算法:
1、用Gumble Softmax sampling计算网络中MP的重要值,计算得到网络权重(MP)和结构参数α,α评估每个块的重要性,e保证exploration是Gumble分布求出的服从均匀分布的独立样本ei,t是temperature超参数,n是不同块的类型,前面这个部分类似于求概率其实就是每个块的重要值,α是x
得到的概率密度(这个没看懂)
2、建表收集4D和3D不同宽度下的延迟;
3、利用表对第一步获得的超级网络进行网络瘦身,通过延迟评估。
-
逐渐瘦身:不像以前直接使用有最大α的块,因为内存可能会很大,而且MP分支太多。
- 分别定义stage1、2和stage3,4的重要性;类似地,每个stage的重要性可以通过总结阶段内所有MP得分来获得;(我的理解是I是最基本的,所以用比来做得分)
-
根据重要性得分,定义行为空间包含三个方向:选择残差作为最小的MP、移除3D、减少最不重要stage宽度,通过查表计算每个操作的结果延迟,并评估每个操作的准确性下降;
-
最后根据每延迟下降的准确率选择相应的操作
实验结果
image classification:更快更好
主干网络:APbox和APmask是什么?
box AP即box average precision,mask AP即mask average precision;
从MetaFormer得到启发,但比poolformer好;
总结
缺点
-
在不同平台上差异会很大,例如不同平台支持或对GeLU和HardSwish操作有优化;
-
提出的延迟驱动网络瘦身完全是因为搜索代价考量,暴力算法可能会找到更好的;
贡献
- 分离4D和3D,减少reshape;
- 延迟驱动网络瘦身方法;