ResMlp

ResMLP是一种受到ViT启发的深度学习模型,但摒弃了注意力机制,仅使用线性层和残差连接。模型结构简单,不依赖批或通道的标准化,具有稳定性和可解释性。实验表明,ResMLP在多个数据集上表现出良好的性能,尤其是在无自注意力的情况下,仍能有效学习图像patch间的交互。
摘要由CSDN通过智能技术生成

该架构极为简单:它采用展平后的N*N个图像 patch 作为输入,他们直接相互独立,通过线性层对其进行映射为N^{2}d维嵌入特征,然后采用两个残差操作对投影特征进行更新:(i)一个简单的线性 patch 交互层,独立用于所有通道;(ii)带有单一隐藏层的 MLP,独立用于所有 patch。在网络的末端,这些 patch 被平均池化,进而馈入线性分类器。然后将输出的 N^{2}d维嵌入特征进行平均得到d维图像表达,最后将图像表达送入线性分类层预测图像对应标签,训练使用交叉熵损失

该架构是受 ViT 的启发,但更加简单:不采用任何形式的注意力机制,仅仅包含线性层与 GELU 非线性激活函数。该体系架构比 Transformer 的训练要稳定,不需要特定 batch 或者跨通道的标准化(如 Batch-Norm、 GroupNorm 或 LayerNorm)。训练过程基本延续了 DeiT 与 CaiT 的训练方式。

由于 ResMLP 的线性特性,模型中的 patch 交互可以很容易地进行可视化、可解释。尽管第一层学习到的交互模式与小型卷积滤波器非常类似,研究者在更深层观察到 patch 间更微妙的交互作用,这些包括某些形式的轴向滤波器(axial filters)以及网络早期长期交互。

架构方法

ResMLP 的具体架构如下图 1 所示,采用了路径展平(flattening)结构:

整体流程

ResMLP 以 N×N 非重叠 patch 组合作为输入,其中 N 通常为 16。然后,这些非重叠 patch 独立地通过一个线性层以形成 一组N^2 的 d 维嵌入。接着,生成的 N^2  d 维嵌入被馈入到一个残差 MLP 层序列中以生成 N^2 个 d 维输出嵌入。这些输出嵌入又被平均为一个表征图像的 d 维向量,这个 d 维向量被馈入到线性分类器中以预测与图像相关的标签。训练中使用到了交叉熵损失。

我们的模型,如图1所示,是受ViT模型的启发,采用了路径扁平化结构。我们着手进行彻底的简化。我们的模型用ResMLP表示,以N x N个不重叠补丁构成的网格作为输入,其中N通常等于16。然后,这些补丁独立地穿过一个线性层,形成一组N2 d维嵌入。所得到的N2嵌入集合被输入到一个残差多层感知器层序列中,以产生一组N2 d维输出嵌入。然后,这些输出嵌入被平均为一个d维向量来表示图像,这个d维向量被提供给一个线性分类器来预测与图像相关的标签。训练使用交叉熵损失剩余的多感知器层。

残差多感知机层

网络序列中的所有层具有相同的结构:线性子层 + 前馈子层。类似于 Transformer 层,每个子层与跳远连接(skip-connection)并行。研究者没有使用层归一化(LayerNormalization),这是因为当使用公式(1)中的 Affine 转换时,即使没有层归一化,训练也是稳定的。

其中,表示可学习向量。需要注意的是:该层推理无耗时,因其参数可与前接线性层合并。Aff独立的作用于X的每一列,尽管与BatchNorm、LayerNorm非常类似,但该操作不依赖任何batch统计;它与近期提出的LayerScale非常接近,但LayerScale没有偏置项。

研究者针对每个残差块都使用了两次 Affine 转换。作为预归一化,Aff 替代了Layernormalization,并不再使用通道级统计(channel-wise statistics)。作为残差块的后处理,Aff 实现了层扩展(LayerScale),因而可以在后归一化时采用与 [50] 中相同的小值初始化。这两种转换在推理时均集成至线性层。

此外,研究者在前馈子层中采用与 Transformer 中相同的结构,并且只使用 GELU 函数替代 ReLU 非线性。

与 Transformer 层的主要区别在于,研究者使用以下公式(2)中定义的线性交互替代自注意力:所提多层感知机层将维输入N^{2}个patch的d维输入特征堆叠为d*N^{2}矩阵X,输出维输出特征,计算公式如下:

其中,A,B,C表示主要的学习参数。参数矩阵A的维度为N^{2}*N^{2},用于混合所有位置的信息,而前馈层则作用于每个位置。因此,中间激活矩阵Z具有与矩阵X、Y相同的维度。最后参数矩阵B和C的维度类似Transformer层,即4d*d,d*4d。

方法

与 ViT 的关联

ResMLP 是 ViT 模型的大幅度简化,但具有以下几个不同点:

ResMLP 没有采用任何自注意力块,使用的是非线性(non-linearity)的线性 patch 交互层;ResMLP 没有采用额外的「类(class)」token,相反只使用了平均池化;

ResMLP 没有采用任何形式的位置嵌入,不需要的原因是 patch 之间的线性通信模块考虑到了 patch 位置;

ResMLP 没有采用预层归一化,相反使用了简单的可学习 affine 转换,从而避免了任何形式的批和通道级统计。

实验结果

Datasets 训练数据为ImageNet,除了ImageNet本身的验证集外,我们还在ImageNet-real、ImageNet-v2数据集上进行了验证测试。

Training paradigms 训练方式考虑了以下两种形式:

  • 监督学习:采用softmax分类+交叉熵损失训练,本文主要聚焦于此;

  • 知识蒸馏:采用ConvNet通过知识蒸馏方式引导ResMLP训练。

Hyper-parameter setting 在监督学习中,我们采用Lamb羽化期,学习率为,权值衰减0.2。超参设置于DeiT类似,知识蒸馏时的老师模型为RegNety-16GF。

首先,研究者将 ResMLP 与 Transformer、convnet 在监督学习框架下进行了比较,如下表 1 所示,ResMLP 取得了相对不错的 Top-1 准确率。

  • 尽管ResMLP在精度、FLOPs以及吞吐量的均衡方面不如ConvNet、Transformer,但其性能仍非常优异;

  • 事实上,这里所对比的ConvNet经过了多年的研究与精心优化才达到了如此好的性能;而本文所提方法只是最简单的适配,未经过多的优化。

其次,利用知识蒸馏提高模型的收敛性,结果如下表 2 所示。

与 DeiT 模型类似,ResMLP 可以从 convnet 蒸馏中显著获益。表中结果表明:前馈网络仍存在过拟合问题。额外的正则技术与蒸馏可以进一步提升模型的性能。

知识蒸馏:https://zhuanlan.zhihu.com/p/102038521

实验还评估了 ResMLP 在迁移学习方面的性能。下表 3 展示了不同网络架构在不同图像基准上的性能表现,数据集采用了 CIFAR-10、CIFAR100、Flowers-1022、 Stanford Cars 以及 iNaturalist 。

权重稀疏性测量也是研究者的关注点之一。下图 2 的 ResMLP-24 线性层的可视化结果表明线性通信层是稀疏的,并在下图 3 中进行了更详细的定量分析。结果表明,所有三个矩阵都是稀疏的,实现 patch 通信的层明显更稀疏。

最后,研究者探讨了 MLP 的过拟合控制,下图 4 控制实验中探索了泛化问题。

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值