MobileViT(ICLR 2022,Apple)论文与代码解析

paper:MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer

official implementation:https://github.com/apple/ml-cvnets

third-party implementation:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/mobilevit.py

MobileViT和MobileNetv2比,精度高了1.5个点,但延迟是人家的八倍!!!

出发点

这篇论文的出发点是结合卷积神经网络(CNNs)和自注意力机制的Vision Transformer(ViTs)的优点,以构建适用于移动设备的轻量级和低延迟的视觉transformer。传统的CNNs在处理空间局部信息上表现出色,而ViTs则擅长于学习全局表示。然而,ViTs模型通常参数量大且延迟高,难以在资源有限的移动设备上高效运行 。

本文的创新点

本文文解决了如何在移动设备上高效运行视觉任务的问题。具体来说,作者提出了一种名为MobileViT的模型,该模型通过结合CNNs和ViTs的优势,既保留了CNNs的轻量级和高效性,又引入了ViTs的全局信息处理能力。实验结果显示,MobileViT在多个任务和数据集上显著优于基于传统CNNs和ViTs的模型 。

MobileViT的优点如下

  1. 结合CNNs和ViTs的优势
    局部与全局信息处理:MobileViT通过将卷积操作与transformer block相结合,既保留了卷积神经网络(CNNs)处理局部信息的能力,又引入了vision transformer(ViTs)处理全局信息的能力。这样的结合使得MobileViT在处理复杂视觉任务时能够更好地捕捉全局和局部特征。
  2. 轻量级和高效性
    参数量与计算量:MobileViT在参数量和计算复杂度方面显著低于传统的ViTs。例如,在ImageNet-1k数据集上,MobileViT的参数量约为600万,而性能却优于相同参数量级别的模型,如MobileNetv3和DeiT。
    低延迟:由于参数量和计算量的优化,MobileViT在移动设备上具有更低的延迟,能够在资源受限的环境中高效运行。
  3. 性能提升
    准确率:实验结果显示,MobileViT在多个数据集和任务上表现优异。例如,在ImageNet-1k数据集上,MobileViT的Top-1准确率为78.4%,比MobileNetv3和DeiT分别高出3.2%和6.2%。
    通用性:MobileViT在不同任务(如图像分类、目标检测、语义分割)中的表现均优于现有的轻量级CNNs和ViTs。这表明其具有较强的通用性,适用于各种视觉任务。
  4. 简单的训练方法
    训练稳定性:MobileViT通过结合CNNs的空间归纳偏置,使其在训练过程中更稳定,不易过拟合。此外,与其他需要复杂数据增强和正则化方法的ViTs模型相比,MobileViT能够在更简单的训练设置下取得较好的性能。
  5. 模型设计创新
    MobileViT Block:MobileViT引入了一种新的MobileViT块,该块通过替换卷积中的局部处理为全局处理,实现了有效的局部和全局信息编码。这种设计使得MobileViT在保持轻量级的同时,能够学习到更好的表示。
  6. 在移动设备上的实际应用
    实际应用性能:在移动设备上,MobileViT展示了优异的性能和效率。例如,在iPhone-12上的测试显示,MobileViT在推理速度和准确率方面均优于其他模型。

方法介绍

MobileViT block的结构如图1(b)所示。具体来说,对于一个输入张量 \(\mathbf{X}\in \mathbb{R}^{H\times W\times C}\),MobileViT首先用一个标准 \(n\times n\) 卷积和一个 \(1\times 1\) 卷积得到 \(\mathbf{X}_L\in \mathbb{R}^{H\times W\times d}\)。为了让MobileViT既学习到全局表示同时又具有卷积的归纳偏置能力,作者将 \(\mathbf{X}_L\) 进行unfold得到 \(N\) 个不重叠的展平的patch \(\mathbf{X}_U\in \mathbb{P}^{N\times W\times d}\),其中 \(P=wh\),\(N=\frac{HW}{P}\) 是patch的数量,\(h\le n,w\le n\) 分别是patch的高和宽。对于每个 \(p\in\{1,...,P\}\),patch间的关系通过transformer编码得到 \(\mathbf{X}_G\in \mathbb{P}^{P\times N\times d}\)

和ViT失去了像素的空间顺序不同,MobileViT即保留了patch的顺序又保留了每个patch内像素的顺序。然后再将 \(\mathbf{X}_G\in \mathbb{P}^{P\times N\times d}\) fold回去得到 \(\mathbf{X}_F\in \mathbb{R}^{H\times W\times d}\)。\(\mathbf{X}_F\) 然后通过一个point-wise卷积映射到一个较低的 \(C\) 维度并与原始的 \(\mathbf{X}\) 进行concat。然后用另外一个 \(n\times n\) 卷积来融合concatenated的特征。注意因为 \(\mathbf{X}_U(p)\) 通过卷积编码了 \(n\times n\) 区域内的局部信息,而 \(\mathbf{X}_G(p)\) 编码了 \(P\) 个patch内 \(p\)-th位置的全局信息,因此 \(\mathbf{X}_G\) 中的每个像素都可以编码来自 \(\mathbf{X}\) 中所有像素的信息,如图4所示

 

Light-weight. MobileViT block和之前的工作一样采用了标准卷积和transformer来分别学习局部和全局表示,那为什么MobileViT可以做到轻量?作者认为之前的ViTs通过transformer学习全局信息,缺少了image-specific inductive bias,因此需要更多的容量来学习视觉表示,从而使网络deep and wide。而MobileViT同时具备学习局部信息和全局信息的能力,使得我们可以将网络设计的shallow and narrow。比如和使用 \(L=12,d=192\) 的DeiT相比,MobileViT分别在32x32、16x16、8x8的spatial level使用 \(L=\{2,4,3\},d=\{96,120,144\}\),得到MobileViT网络和DeiT相比,快了1.85x,小了2x,精度提升了1.8%。

MobileViT architecture. 作者设计了三种不同大小的MobileViT,S: small, XS: extra small, XXS: extra extra small。具体结构如图1(b)和下表所示。

 

第一层采用strided 3x3标准卷积,然后是若干MobileNet v2 block和MobileViT block。激活函数采用Swish。MobileViT block中卷积核大小采用 \(n=3\),所有spatial level的patch大小采用 \(h=w=3\)。

Multi-scale Sampler for Training Efficiency

在ViT-based模型中,学习多尺度表示的一种标准方法是进行微调,这种方法对于ViTs更合适因为对于不同大小的输入位置编码需要进行相应的插值。而MobileViT和CNN类似,不需要使用位置编码,因此可以直接进行多尺度训练。之前CNN中的多尺度训练的做法通常是在固定数量的iteration后采样一个新的分辨率,这会导致GPU利用率不足并训练较慢,因为不同的分辨率下都采用相同的batch size(根据预定义的最大分辨率确定)。为了提高训练效率,本文提出了在多尺度训练中,针对不同的分辨率采用不同的batch size,具体就是在训练前对于最大的分辨率 \((H_n,W_n)\) 确定批量大小 \(b\),在训练中随机选择一个分辨率 \((H_t,W_t)\),对应批量大小根据式 \(b_t=\frac{H_nW_nb}{H_tW_t}\) 得到。

实验结果

不是,这MobileViT和MobileNetv2比,精度高了1.5个点,但延迟是人家的八倍,这篇文章搞了个寂寞啊。

代码解析

这里是timm中的代码,模型选择了"mobilevit_s",输入大小为(1, 3, 256, 256)。核心代码在class MobileVitBlock中,我们直接来看forward函数。经过前面stem的处理后,这里输入大小为(1, 96, 32, 32),对照图1可以看到这里的区别就是没有把patch的spatial维度变为1全放到通道维度中,例如对于普通的transformer,一个高宽为hw通道为d的patch要转换成(1, hwd),而在这里转换为(1, hw, d)。

下面代码在conv_1x1处理后的输出为(1, 144, 32, 32),patch_size=2,则每个patch的面积为2x2=4,patch的数量为(32/2)x(32/2)=256,对于普通的transformer转换后的维度为(1, 144x4, 256),而这里为(1, 256, 4, 144)。self.transformer还是普通的实现,对于普通的输入(batch_size, seq_len, fea_dim),transformer是对每个batch的(seq_len, fea_dim)进行处理,而这里transformer是对每个patch内的同一位置的像素进行处理,所以把hw=4和batch_size相乘放到了一起。

def forward(self, x: torch.Tensor) -> torch.Tensor:
    shortcut = x  # (1,96,32,32)

    # Local representation
    x = self.conv_kxk(x)  # 3x3-s1 conv, (1,96,32,32)
    x = self.conv_1x1(x)  # (1,144,32,32)

    # Unfold (feature map -> patches)
    patch_h, patch_w = self.patch_size  # 2,2
    B, C, H, W = x.shape
    new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w  # 32,32
    num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w, 16,16
    num_patches = num_patch_h * num_patch_w  # N, 256
    interpolate = False
    if new_h != H or new_w != W:
        # Note: Padding can be done, but then it needs to be handled in attention function.
        x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
        interpolate = True

    # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w]
    x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w, patch_w).transpose(1, 2)  # (1,144,32,32)->(2304,2,16,2)->(2304,16,2,2)
    # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w
    x = x.reshape(B, C, num_patches, self.patch_area).transpose(1, 3).reshape(B * self.patch_area, num_patches, -1)
    # (1,144,256,4)->(1,4,144,256)->(4,256,144)

    # Global representations
    x = self.transformer(x)  # (4,256,144)
    x = self.norm(x)

    # Fold (patch -> feature map)
    # [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w]
    x = x.contiguous().view(B, self.patch_area, num_patches, -1)  # (1,4,256,144)
    x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w, patch_h, patch_w)  # (1,144,256,4)->(2304,16,2,2)
    # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W]
    x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)  # (2304,2,16,2)->(1,144,32,32)
    if interpolate:
        x = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=False)

    x = self.conv_proj(x)  # 1x1-s1-96 conv, (1,96,32,32)
    if self.conv_fusion is not None:  # 3x3-s1-96
        x = self.conv_fusion(torch.cat((shortcut, x), dim=1))  # (1,96,32,32)
    return x

 

  • 17
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00000cj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值