苹果公司提出Mobile-ViT | 更小更轻精度更高,MobileNets或成为历史

在这里插入图片描述
MobileviT是一个用于移动设备的轻量级通用可视化Transformer,据作者介绍,这是第一次基于轻量级CNN网络性能的轻量级ViT工作,性能SOTA!。性能优于MobileNetV3、CrossviT等网络。

1.简介

轻量级卷积神经网络(CNN)是移动视觉任务的实际应用,它们的空间归纳偏差允许它们在不同的视觉任务中以较少的参数学习表征.然而,这些网络在空间上是局部.为了学习全局表征,采用基于自注意力的Vision Transformer(ViTs).
在本文中,本文提出了以下问题:是否有可能结合CNN和ViT的优势,构建一个轻量级 低延迟的移动视觉任务网络?
为此提出了MobileViT,一种轻量级的、通用的移动设备Vision Transformer。MobileViT提出了一个不同的视角,以Transformer作为卷积处理信息。
在这里插入图片描述

实验结果表明,在不同的任务和数据集上.MobileViT显著优于基于CNN和ViT的网络.
在ImageNet-1k数据集上,MobileViT在大约600万个参数的情况下达到了78.4%的Top-1准确率,对于相同的数量的参数,比如Mobilenetv3和deit的准确率分别高出3.2%和6.2%.

在MS-COCO目标检测任务中,在参数数量相近的情况下,mobilevit比mobilentv3的准确率高出5.7%.

2.相关工作

2.1轻量化CNN模型

CNN的基本构建层是标准的卷积层,由于这一层的计算成本很高,人们提出了几种基于因子分解的方法,使其变得轻量化以方便移动设备的部署.
其中,深度可分离卷积引起了人民的兴趣,并被广泛应用于最先进的轻量级CNN移动视觉任务,包括Mobilenets,shuffenetv2,espnetv2,mixnet和mnasnet.这些轻量级CNN是多功能的,易于训练.
例如,这些网络可以很容易地取代现有特定任务模型(如deeplabv3)中的主干网络(如resnet),以减少网络规模并降低延迟.尽管有这些好处,但是这些方法的一个主要缺点是它们在空间上局部的.
这项工作将transformer视为卷积,允许利用卷积和transformer(例如,全局处理)的优点来构建轻量级和通用ViT模型.

2.2 Vision Transformer

Vision Transformer应用于大尺度图像识别,结果表明,在超大尺度数据集(如JFT-300M)下,ViTs可以实现CNN级的精度,而不存在图像特异性的归纳偏差.

通过广泛的数据增强,大量的L2正则化和蒸馏,可以在ImageNet数据集上训练ViT,以实现CNN级别的性能.然而,与CNN不同的是ViT的优化性能不佳,而且很难训练.

例如,ViT-C为ViT前阶段增加了一个卷积Backbone.

CvT改进了transformer的Multi-head attention,并使用深度可卷积代替线性投影.

botnet用multi-head attention 取代了resnet bottleneck unit的标准卷积.

conViT采用gated positional self-attention的soft convolutional归纳偏差.

PiT使用基于深度卷积的池化层拓展了ViT.

虽然这些模型使用data augmentation可以达到与CNN差不多的性能,但是这些模型大多数是heavy-weight。例如,PiT和CvT比EfficientNet多学习6.1倍和1.7倍的参数,在imagenet-1k数据集上取得了相似的性能(top-1的准确度为81.6%)

此外,**当这些模型被缩小以构建轻量级ViT模型时,它们的性能明显比轻量级CNN性能差.**对于大约600万的参数预算,PiT的imagenet-1k精度比mobilenetv3低2.2%.

2.3 讨论

与普通的ViT相比,将卷积和transformer相结合可以得到鲁棒的高性能ViT.然而,一个开放的问题就是:如何结合卷积和transformer来构建轻量级网络的移动视觉任务?

这篇文章的重点是设计轻量的ViT模型,通过简单训练胜过最先进的模型.为此,作者设计了mobilevit,它结合了CNN和ViT的优势构建了一个轻量级,通用和移动设备友好的网络模型.

MobileViT带来了一些新的结果:
1.更好的性能:在相同的参数情况下,余现有的轻量级CNN相比,mobilevit模型在不同的移动视觉任务中实现了更好的性能.
2.更好的泛化能力:泛化能力是指训练和评价指标之间的差距.对于具有相似的训练指标的两个模型,具有更好评价指标的模型更具有通用性,因为它可以更好地预测未见的数据集.与CNN相比,即使有广泛的数据增强,其泛化能力也很差,mobilevit显示出更好的泛化能力(如下图).
3.更好的鲁棒性:一个好的模型应该对超参数具有鲁棒性,因为调优这些超参数会消耗时间和资源.与大多数基于ViT的模型不同,mobilevit模型使用基于增强训练,与L2正则化不太敏感.

3.mobile-vit

3.1问题阐述

在这里插入图片描述
在上图,一个标准的ViT模型将输入reshape为patches,将其投影到固定的d维空间,然后使用L个transformer block学习patches之间的表征.

vision transformer由于这些模型忽略了CNN模型固定的空间归纳偏差,所以它们需哟啊更多的参数来学习视觉表征.例如,与CNN的网络deeplabv3相比,基于vit的网络dpt多学习了6倍的参数才可以提供相似的分割性能(DPT vs DeepLabv3:345 M vs. 59 M ).此外,与CNN相比,这些模型的优化性能不佳.这些模型对L2正则化很敏感,需要大量的数据增强以防止过拟合.

3.2 mobile-block

在这里插入图片描述
mobilevit block如上图所示.其目的是用较少的参数对输入张量中的局部和全局信息进行建模.

形式上,对于一个给定的输入张量,mobilevit首先应用了一个nn的标准卷积层,然后用一个一个点(或者11)卷积层产生特征.n*n 卷积层编码局部空间信息,而点卷积通过学习输入通道的线性组合将张量投影到高维空间(d维,其中d>c)

mobilevit
通过mobilevit,希望在拥有有效感受野的同时,对远距离非局部依赖进行建模.一种被广泛研究的建模远程依赖关系的方法是扩张卷积.然而,这种方法需要谨慎选择膨胀率.否则,权重将应用于填充的零而不是有效的空间区域.

另一个有希望的解决方案是self-attention.在self-attention方法中,具有multi-head self-attention的vision transformers(ViTs)在视觉识别任务中是有效的。然而,vit是heavy-weight,并由于vit缺乏空间归纳偏差,表现出较差的可优化性。

为了使MobileViT能够学习具有空间归纳偏差的全局表示,作者将展开为N个non-overlapping flattened patches 。其中P=wh,N=HW/P为patch的个数,h≤N, w≤N分别为patch的高度和宽度。,通过应用Transformer来编码patch间的关系:

在这里插入图片描述
与丢失像素空间顺序的vit不同,mobilevit既不丢失patch顺序,也不丢失每个patch内像素的空间顺序.因此,可以将折叠得到.然后,通过逐点卷积投影到低维空间(c维),并通过cat操作与x结合.

然后使用另一个nn卷积层来融合级联张量中的局部和全局特征.注意,因为使用卷积对nn区域的局部信息进行编码,而对于第p个位置的p个patch对全局信息进行编码.每一个像素可以对x中所有像素的信息进行编码,如下图所示,因此,mobilevit的整体有效接受域为H*W.
在这里插入图片描述在MobileViT Block中,每个像素都可以感知到其他像素

在上图中,红色像素通过transformer处理蓝色像素(其他patch中相应位置的像素)。因为蓝色像素已经使用卷积对邻近像素的信息进行了编码,这就允许红色像素对图像中所有像素的信息进行编码.在这里,黑色和灰色网格中的每个单元分别表示一个patch和一个像素.

3.2.1 与CNN的关系

标准卷积可以看作是3个顺序操作的堆叠:

  1. 展开
  2. 矩阵乘法(学习局部表示)
  3. 折叠
    mobilevit block与卷积相似,因为它们也利用了相同的构建块,mobilevit block将卷积层中的局部处理(矩阵乘法)代替为更深层次的全局处理(一个transformer层堆栈).因此,mobilevit具有类似于卷积的属性(例如,空间偏差).因此,mobilevit可以看作是transformer的卷积.

这里设计的简单的一个优势是,卷积和transformer的低层次高效实现可以开箱即用
允许在不同的设备上使用mobilevit而不需要任何额外的负担.

3.2.2 为什么是light-weight?

mobilevit block使用标准卷积和transformer分别学习局部和全局表示.因为之前的工作表明,使用这些层设计的网络是heavy-weight,一个自然的问题出现了:为什么mobilevit是轻量级的?

作者认为问题主要在于通过transformer学习全局表示.对于给定的patch,之前的工作是通过学习像素的线性组合将空间信息转化成潜在信息.然后,通过使用transformer对全局信息进行编码学习patch之间信息.因此,这些模型失去了图像特定的归纳偏差(CNN模型固定的特点)

因此,需要更多的参数来学习视觉表征.所以,那些模型又深又宽.与这些模型不同的是,mobilevit使用卷积和transformer的方式是,生成的mobilevit block具有类似卷积的属性,同时允许全局处理.这种建模能力能够设计出浅而窄的mobilevit模型,从而使得weight更轻.

与基于ViT的DeiT模型(L=12和d=192)相比,MobileViT模型在空间层面分别使用了32×32、16×16和8×8的和。由此产生的MobileViT网络比DeiT网络更快(1.85×),更小(2×),更好(+1.8%)。

3.2.3 计算复杂度

理论上,与vit相比,MobileViT效率较低。然而,在实践中,MobileViT比vit更有效率。

3.2.4 MobileViT架构

受到轻量级CNN的启发。作者用3种不同的网络规模(S:小,XS:特别小,XXS:特别小)训练MobileViT模型,这3种网络规模通常用于移动视觉任务(图3c)。
在这里插入图片描述
MobileViT的初始层是一个stride=3×3的标准卷积,其次是MobileNetv2(或MV2) Block和MobileViT Block。使用Swish作为激活函数。跟随CNN模型,在MobileViT块中使用n=3。

特征图的空间维数通常是2和h的倍数。因此,在所有空间层面设h=w=2。MobileViT网络中的MV2块主要负责降采样。因此,这些块在MobileViT网络中是浅而窄的。下图中MobileViT的Spatiallevel-wise参数分布进一步说明了在不同的网络配置中,MV2块对总网络参数的影响非常小。

在这里插入图片描述

3.2.5 多尺度采样训练

在基于vit的模型中,学习多尺度表示的标准方法是微调.例如,在不同尺寸上对经过224×224空间分辨率训练的DeiT模型进行了独立微调。由于位置嵌入需要根据输入大小进行插值,而网络的性能受插值方法的影响,因此这种学习多尺度表示的方法对vit更有利。与CNN类似,MobileViT不需要任何位置嵌入,它可以从训练期间的多尺度输入中受益。

先前基于CNN的研究表明,多尺度训练是有效的。然而,大多数都是经过固定次数的迭代后获得新的空间分辨率。

例如,YOLOv2在每10次迭代时从预定义的集合中采样一个新的空间分辨率,并在训练期间在不同的gpu上使用相同的分辨率。这导致GPU利用率不足和训练速度变慢,因为在所有分辨率中使用相同的批大小(使用预定义集中的最大空间分辨率确定)。

在这里插入图片描述

为了便于MobileViT在不进行微调的情况下学习多尺度表示,并进一步提高训练效率(即更少的优化更新),作者将多尺度训练方法扩展到可变大小的Batch-Size。给定一组排序的空间分辨率和Batch-Size b,最大空间分辨率为,在每个GPU上随机采样空间分辨率,并计算第t次迭代的Batch-Size为:。因此,更大的Batch-Size用于更小的空间分辨率。这减少了优化器每个epoch的更新,有助于更快的训练。

在这里插入图片描述
上图比较了标准采样器和多尺度采样器。在这里,将PyTorch中的distributed-dataparelles称为标准采样器。总体而言,多尺度采样器:

  1. 减少了训练时间,因为它需要更少的优化器更新不同大小的batch-size
  2. 提高了约0.5%的性能(下图);
    在这里插入图片描述
    注释:促使网络学习更好的多尺度表征,即在不同的空间分辨率下评估相同的网络,与使用标准采样器训练的网络相比,具有更好的性能。

pytorch复现代码如下:
在这里插入图片描述
作者还进行了通用化实验说明了多尺度采样器是通用的,并改善了CNN如MobileNetv2)的性能
在这里插入图片描述

4 实验结果

4.1 ImageNet-1K

在这里插入图片描述
图4.1 Light-CNNs参数对比
在这里插入图片描述
图4.2 Light-CNNs精度对比
在这里插入图片描述
图4.3 CNNs精度对比
图4.1显示了MobileViT在不同网络规模(MobileNetv1、MobileNetv2、ShuffleNetv2、ESPNetv2和MobileNetv3)上的性能优于轻量级CNN。例如,对于一个大约有250万个参数的模型(图4.2),在ImageNet-1k验证集上,MobileViT比MobileNetv2、ShuffleNetv2和MobileNetv3的性能分别高出5.0%、5.4%和7.4%。

图4.3进一步显示,MobileViT提供了比Heavy-weight CNN(ResNet, DenseNet, ResNet-se和EfficientNet)更好的性能。例如,对于相同数量的参数,MobileViT比effentnet的准确率高出2.1%。

在这里插入图片描述
图4.4 ViTs参数对比
在这里插入图片描述
图4.5 ViTs精度对比
图4.4比较了MobileViT和在ImageNet-1k数据集上从头开始训练的ViT变体(DeIT、T2T、PVT、CAIT、DeepViT、CeiT、CrossViT、LocalViT、PiT、ConViT、ViL、BoTNet和Mobile-former)。

不像ViT变体,显著受益于高级的数据增强(例如,PiT w/ basic vs. advanced: 72.4%(R4) vs. 78.1%(R17);图4.5), MobileViT通过更少的参数和基本的增强实现了更好的性能。例如,MobileViT比DeiT小2.5,好2.6%。

总的来说,这些结果表明,与CNN相似,MobileViTs易于优化和鲁棒性强。因此,它们可以很容易地应用于新的任务和数据集。

4.2 目标检测任务

在这里插入图片描述
图4.6 Light-CNN对比
在这里插入图片描述
图4.6 Heavy-CNN对比
图4.6显示,在320×320的相同输入分辨率下,基于MobileViT的SSDLite与其他轻型CNN模型(MobileNetv1、MobileNetv2、MobileNetv3、MNASNet和MixNet)相比,性能更好。

例如,当使用MobileViT而不是MNASNet作为Backbone时,SSDLite的性能提高了1.8%,其模型尺寸减少了1.8×。此外,基于MobileViT的SSDLite性能优于Heavy-CNN Backbone的标准SSD-300,同时学习的参数明显更少。

4.3 语义分割任务

在这里插入图片描述
图4.8 语义分割结果对比
图4.8显示了带有MobileViT的DeepLabv3更小更好。使用MobileViT代替MobileNetv2作为Backbone时,DeepLabv3的性能提高了1.4%,体积减少了1.6×。此外,MobileViT提供了具有竞争力的性能与模型renet-101相比,所需参数减少了9倍。

5 代码

import torch
import torch.nn as nn

from einops import rearrange

class SiLU(torch.nn.Module):  # export-friendly version of nn.SiLU()
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)
def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        SiLU()
    )


def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        SiLU()
    )


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),

            SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class MV2Block(nn.Module):
    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()
        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)
        # Global representations
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
                      pw=self.pw)
        # Fusion
        x = self.conv3(x)
        x = torch.cat((x, y), 1)  # 合并
        x = self.conv4(x)   # 卷积成输入的尺寸
        return x


class MobileViT(nn.Module):
    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
        super().__init__()
        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

        L = [2, 4, 3]

        self.conv1 = conv_nxn_bn(3, channels[0], stride=2)

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))  # Repeat
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

        self.mvit = nn.ModuleList([])
        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))

        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

        self.pool = nn.AvgPool2d(ih // 32, 1)
        self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.mv2[0](x)

        x = self.mv2[1](x)
        x = self.mv2[2](x)
        x = self.mv2[3](x)  # Repeat

        x = self.mv2[4](x)
        x = self.mvit[0](x)

        x = self.mv2[5](x)
        x = self.mvit[1](x)

        x = self.mv2[6](x)
        x = self.mvit[2](x)
        x = self.conv2(x)

        x = self.pool(x).view(-1, x.shape[1])
        x = self.fc(x)
        return x


def mobilevit_xxs():
    dims = [64, 80, 96]
    channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
    return MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2)


def mobilevit_xs():
    dims = [96, 120, 144]
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
    return MobileViT((256, 256), dims, channels, num_classes=1000)


def mobilevit_s():
    dims = [144, 192, 240]
    channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
    return MobileViT((256, 256), dims, channels, num_classes=1000)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == '__main__':
    img = torch.randn(5, 3, 256, 256)

    vit = mobilevit_xxs()
    out = vit(img)
    print(out.shape)
    print(count_parameters(vit))

    vit = mobilevit_xs()
    out = vit(img)
    print(out.shape)
    print(count_parameters(vit))

    vit = mobilevit_s()
    out = vit(img)
    print(out.shape)
    print(count_parameters(vit))
  • 4
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值