1. Contribution
We propose Residual Multi-Layer Perceptrons (ResMLP): a purely multi-layer perceptron (MLP) based architecture for image classification.
(i) a linear layer in which image patches interact, independently and identically across channels,
and (ii) a two-layer feed-forward network in which channels interact independently per patch.
Figure 1 shows the ResMLP architecture. The pipeline, in brief: the network input, two residual operations (a linear layer and an MLP with a single hidden layer), and finally an average pooling layer and a linear classifier:
- it takes flattened patches as input, projects them with a linear layer, and updates them in turn with two residual operations:
- (i) a simple linear layer that provides interaction between the patches, which is applied to all channels independently;
- (ii) an MLP with a single hidden layer, which is independently applied to all patches.
- At the end of the network, the patches are average pooled, and fed to a linear classifier.
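The two residual operations above can be sketched in numpy. This is a minimal illustration of the shapes and data flow, not the paper's implementation: the dimensions, random weight initializations, and the tanh-approximate GELU are assumptions for the sketch.

```python
import numpy as np

# Illustrative sizes: 196 = 14*14 patches, embedding dim d = 384.
n_patches, d, hidden = 196, 384, 1536
rng = np.random.default_rng(0)

x = rng.standard_normal((n_patches, d))          # patch embeddings

# (i) cross-patch sublayer: one linear map over the patch axis,
#     shared across all d channels, plus a residual connection.
A = rng.standard_normal((n_patches, n_patches)) * 0.01
z = x + A @ x

# (ii) per-patch sublayer: a two-layer MLP applied to each patch
#      independently (same weights for every patch), plus a residual.
W1 = rng.standard_normal((d, hidden)) * 0.01
W2 = rng.standard_normal((hidden, d)) * 0.01
gelu = lambda t: 0.5 * t * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (t + 0.044715 * t**3)))
y = z + gelu(z @ W1) @ W2

print(y.shape)  # (196, 384) -- both sublayers preserve the embedding shape
```

Note that (i) mixes information *across* patches per channel, while (ii) mixes information *across* channels per patch; stacking the two gives full connectivity without attention.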
2. Summary
- Despite their simplicity, Residual Multi-Layer Perceptrons can reach surprisingly good accuracy/complexity trade-offs with ImageNet-1k training only, without requiring normalization based on batch or channel statistics.
- These models benefit significantly from distillation methods.
- Thanks to its design, where patch embeddings simply "communicate" through a linear layer, we can observe what kind of spatial interactions the network learns across layers.
3. Methods
3.1 The overall ResMLP architecture
Each ResMLP layer consists of a linear sublayer followed by a feed-forward sublayer.
The input image is split into non-overlapping patches of size 16×16, giving an N×N grid (N = 224/16 = 14 for a 224×224 image). The flattened patches are passed through a linear layer to obtain N² d-dimensional patch embeddings.
These N²×d embeddings are then fed through the ResMLP layers, which preserve their shape: 224 × 224 × 3 → 14 × 14 × (16 × 16 × 3) → N² × d.
Finally, the N²×d representations are average-pooled into a single d-dimensional vector, which is fed to a linear classifier [d, C] to predict the label; the network is trained with a cross-entropy loss.
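The shape bookkeeping above can be verified with a short numpy walk-through. The weights are zero placeholders and the dimensions (d = 384, C = 1000) are illustrative assumptions, not values from the notes:

```python
import numpy as np

# 224x224 RGB image, 16x16 patches -> N = 224 // 16 = 14, i.e. 196 patches.
img = np.zeros((224, 224, 3))
p, d, num_classes = 16, 384, 1000
N = img.shape[0] // p                              # 14

# Flatten non-overlapping patches into rows: (196, 16*16*3) = (196, 768).
patches = img.reshape(N, p, N, p, 3).transpose(0, 2, 1, 3, 4).reshape(N * N, p * p * 3)

W_embed = np.zeros((p * p * 3, d))                 # linear patch projection
x = patches @ W_embed                              # (196, d) patch embeddings

# ... the ResMLP layers operate on x, keeping the (196, d) shape ...

v = x.mean(axis=0)                                 # average pool -> (d,)
W_cls = np.zeros((d, num_classes))                 # linear classifier [d, C]
logits = v @ W_cls                                 # (num_classes,)
print(patches.shape, x.shape, logits.shape)
```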
3.2 The Residual Multi-Perceptron Layer
The authors replace the Layer Normalization in the Transformer layer with an Affine transformation:

Aff_{α,β}(x) = Diag(α) x + β

where α and β are learnable weight vectors; the operation simply rescales and shifts the input element-wise, without using any batch or channel statistics.
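A minimal numpy sketch of this Aff operation follows; the particular values of α and β are illustrative, and in the real network they would be learned:

```python
import numpy as np

def affine(x, alpha, beta):
    """Element-wise rescale and shift per channel: Diag(alpha) @ x + beta.

    x: (n_patches, d) embeddings; alpha, beta: (d,) learnable vectors.
    Unlike LayerNorm, no mean/variance statistics are computed.
    """
    return alpha * x + beta

x = np.ones((4, 3))
alpha = np.array([2.0, 1.0, 0.5])
beta = np.array([0.0, 1.0, -1.0])
out = affine(x, alpha, beta)
print(out[0])  # [ 2.   2.  -0.5]
```

Because Aff is a fixed per-channel scale and shift at inference time, it can be folded into the adjacent linear layer, so it adds no inference cost.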