视频讲解
【CCF-A】针对2D图像的轴向稀疏连接MLP方案|中国科学技术大学
Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?(AAAI 2022)
论文地址: https://ojs.aaai.org/index.php/AAAI/article/view/20133
1、摘要总结
本文探索了Transformer中的核心自注意力模块是否是实现图像识别优异性能的关键因素。
本文提出一种新颖的稀疏MLP模块(sMLP)。sMLP模块在2D图像标记中沿轴向方向应用1D MLP,并在行或列之间共享参数。这种方法通过稀疏连接和权重共享显著降低了模型参数的数量和计算复杂度,从而避免了常见的过拟合问题。
实验结果表明,在计算机视觉领域,自注意力机制并非取得优异成绩的唯一途径。
2、研究背景
ViT将图像切分为补丁(patch),并将这些补丁序列化为线性的嵌入(embedding)输入给视觉Transformer。尽管ViT在大规模预训练后表现出色,但其计算复杂度较高,不适合高分辨率输入。
无卷积视觉Transformer实际上推动了两种观点。首先,全局依赖建模很重要。不仅如此,它甚至可以完全取代过去通过卷积嵌入到模型中的局部依赖建模。其次,自我注意力很重要。
然而,既然局部性在自然图像中始终有效,为什么要通过全局自注意力模块来学习它,而不是直接将其注入网络?此外,全局自注意力的计算复杂度是输入 token 数量的二次方。因此,这种金字塔结构网络结构不利于高分辨率输入,对图像质量不友好。这些被认为是视觉 Transformers 的缺点,因为高分辨率输入和金字塔结构已被广泛认可以提高图像识别精度。此外,针对自我注意力也有一些看法,MLP-Mixer (Tolstikhin et al 2021) 认识到建模全局依赖关系的重要性,但它采用 MLP 块而不是自注意力模块来实现它,MLP-Mixer 继承了 ViT 的所有缺点,由于参数数量过多,它很容易过拟合。
3、现有工作的不足
现有的MLP基础模型,如gMLP、ResMLP和EA等,在聚合空间信息时参数和计算复杂度与输入图像大小呈二次关系。这些方法与最先进的模型相比仍存在一定的准确度差距。主要原因在于忽视了局部性偏置、缺少金字塔结构以及由于参数过多导致的过拟合问题。
4、创新点与解决问题的方法
与现有的基于 MLP 的方法相比,我们的方法尝试利用基于金字塔结构的局部偏差和全局依赖性。由于 MLP 全局依赖建模的二次计算复杂度对金字塔结构不友好,我们使用提出的稀疏 MLP 模块来支持基于金字塔结构的全局依赖建模。
4.1 创新点
- 稀疏MLP(sMLP)模块:对于2D图像tokens,sMLP模块沿轴向应用1D MLP,并在行或列间共享参数。这种方法能够减少参数数量和计算复杂度,同时保持模型性能。参数共享:sMLP模块在行或列之间共享参数,进一步降低了模型复杂度。
- 性能优势:在ImageNet-1K数据集上,sMLPNet仅使用24M参数时达到了81.9%的Top-1准确率,优于大多数CNN和视觉Transformer;当扩展到66M参数时,准确率达到83.4%,与最先进的Swin Transformer相当。
4.2 Sparse MLP (sMLP) 的设计目标
- 减少参数数量以避免过拟合,尤其是在像 ImageNet-1K 这样的中等规模数据集上训练时尤为重要。
- 减少计算复杂度,特别是在令牌数量较大的情况下,以便于实现多级处理的金字塔结构。
图 2(b) 展示了标记混合模块。在该模块中,使用核大小为 3x3 的深度卷积(DWConv),从而利用了局部性偏差。事实上,在将信道处理从空间处理中分解出来后,DWConv 就成了探索局部性的一个非常自然的选择。这一操作的效率很高,因为它只包含很少的参数,在推理过程中只需要很少的 FLOPs。我们还尝试利用所提出的 sMLP 模块对全局依赖性进行建模。与原始 MLP 模块相比,sMLP 的稀疏性和权重共享性使其不易过度拟合。
Figure 3: 所提出的 sMLP 块的结构。它由三个分支组成:其中两个分别负责沿水平和垂直方向混合信息,另一个路径是恒等映射。三个分支的输出被连接起来并通过逐点卷积处理以获得最终输出。
图 1:所提出的稀疏 MLP 通过稀疏连接和权重共享降低了 MLP 的计算复杂度。在 MLP (a) 中,深橙色标记与单个 MLP 层中的所有其他标记交互。相比之下,在一个 sMLP 层 (b) 中,深橙色标记仅与用浅橙色标记的水平和垂直标记交互。当 sMLP 执行两次时,可以实现与所有其他白色标记的交互。
4.3 Sparse MLP(sMLP) 的实现
在 sMLP 中,一个令牌不是与所有其他令牌交互,而是只直接与同一行或同一列的令牌交互。此外,所有的行和列可以分别共享相同的投影权重。这样可以显著减少参数数量和计算复杂度。
5、实验
5.1 对比实验
6、应用场景
- 图像分类:在ImageNet-1K数据集上的图像分类任务中展现出优秀的性能。
- 计算机视觉任务:可以应用于广泛的计算机视觉任务中,如目标检测、分割等。
7、局限性
尽管sMLPNet在图像分类任务上表现出色,但由于全连接层的固定性质,MLP-like模型难以处理任意分辨率的输入图像。这意味着这类模型难以用于物体检测和语义分割等下游任务。研究团队计划在未来工作中探索构建更通用的无注意力网络的可能性。
8、结论
本文中提出的 sMLP 块具有稀疏连接和权重共享的特点。通过沿轴向分别聚合信息,sMLP 避免了传统 MLP 的二次模型大小和二次计算复杂度。实验结果表明,这极大地提高了类似 MLP 的视觉模型的性能边界。
9、即插即用模块代码共享
import torch
from torch import nn
# sMLPBlock.py
# --------------------------------------------------------
# 论文: Sparse MLP for Image Recognition: Is Self-Attention Really Necessary? (AAAI 2022)
# 论文地址: https://ojs.aaai.org/index.php/AAAI/article/view/20133
# 抖音、B站、小红书、CSDN 布尔大学士
# ------
# 定义sMLPBlock类,继承自nn.Module
class sMLPBlock(nn.Module):
# 初始化函数,定义网络结构
def __init__(self, h=224, w=224, c=3): # 输入参数为图像的高度h、宽度w以及通道数c
super().__init__() # 调用父类构造器
self.proj_h = nn.Linear(h, h) # 定义沿高度方向的线性变换层
self.proj_w = nn.Linear(w, w) # 定义沿宽度方向的线性变换层
self.fuse = nn.Linear(3 * c, c) # 定义融合层,用于融合来自不同方向的信息
# 前向传播函数
def forward(self, x):
# 沿高度方向进行线性变换,并调整维度顺序
# [B,C,H,W] ---》 [B,C,W,H]
# 因为 nn.Linear 默认对最后一个维度进行操作,所以这里先线性层后调换位置【常见套路】
x_h = self.proj_h(x.permute(0, 1, 3, 2)).permute(0, 1, 3, 2) # torch.Size([50, 3, 224, 224])
# 沿宽度方向进行线性变换
x_w = self.proj_w(x) # torch.Size([50, 3, 224, 224])
# 保留原始输入作为残差连接
x_id = x # torch.Size([50, 3, 224, 224])
# 在通道维度上合并不同的特征
x_fuse = torch.cat([x_h, x_w, x_id], dim=1) # torch.Size([50, 9, 224, 224])
# 融合信息并调整维度顺序
x_fuse_Total = x_fuse.permute(0, 2, 3, 1) # torch.Size([50, 224, 224, 9])
# 因为 nn.Linear 默认对最后一个维度进行操作,所以这里先线性层后调换位置【常见套路】
out = self.fuse(x_fuse_Total).permute(0, 3, 1, 2) # torch.Size([50, 3, 224, 224])
return out # 返回处理后的输出张量
# 测试代码
if __name__ == '__main__':
input = torch.randn(50, 3, 224, 224) # 随机生成一批大小为50x3x224x224的输入数据
print(input.shape) # torch.Size([50, 3, 224, 224])
# 实例化sMLPBlock对象
smlp = sMLPBlock(h=224, w=224)
# 将输入数据传递给sMLPBlock进行前向传播
out = smlp(input)
print(out.shape) # torch.Size([50, 3, 224, 224])
print("抖音、B站、小红书、CSDN同号") # 输出作者的社交平台账号信息
print("布尔大学士 提醒您:代码无误~~~~") # 输出作者的提醒信息