EfficientViT:高分辨率密集预测的多尺度线性注意力

标题:EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction

论文:https://arxiv.org/abs/2205.14756

中文版:【读点论文】EfficientViT: Enhanced Linear Attention for High-Resolution Low-Computation将softmax注意力转变为线性注意力_羞儿的博客-CSDN博客

代码:https://codeload.github.com/mit-han-lab/efficientvit/zip/refs/heads/master

目录

一、摘要

二、主要贡献

三、方法论

3.1 Multi-Scale Linear Attention(多尺度线性注意力) 

3.2 EfficientViT架构

四、实验

4.1 消融研究

4.2 语义分割实验

五、总结


一、摘要

研究背景高分辨率密集预测使许多有吸引力的现实世界的应用,如计算摄影,自动驾驶等,然而,巨大的计算成本使得部署最先进的高分辨率密集预测模型的硬件设备上的困难

主要工作:本文提出了一种新的多尺度线性attention的高分辨率视觉模型——EfficientViT。与之前的高分辨率密集预测模型依赖于大量的softmax关注、硬件低效的大核卷积或复杂的拓扑结构来获得良好的性能不同,多尺度线性attention只需要轻量级和硬件高效的操作就能实现全局接受场和多尺度学习(高分辨率密集预测的两个理想特征)。

研究成果:因此,在各种硬件平台(包括移动CPU、边缘GPU和云GPU)上,EfficientViT比以前的最先进型号提供了显著的性能提升。在Cityscapes(数据集)上没有性能损失的情况下,EfficientViT分别比SegFormer和SegNeXt提供了高达13.9倍和6.2倍的GPU延迟减少。对于超分辨率,EfficientViT比Restormer提供高达6.4倍的加速,同时提供0.11dB的PSNR增益。

二、主要贡献

1. 引入了一个新的多尺度线性注意力模块,用于高效的高分辨率稠密预测。它实现了全局感受野多尺度学习同时保持了良好的硬件效率。据我们所知,我们的工作是第一个证明线性注意力对高分辨率密集预测的有效性

2. 我们设计了高效vit,一个新的高分辨率系列基于视觉模型,提出了多尺度线性注意模块

3. EfficientViT在不同硬件平台(移动的CPU,边缘GPU和云GPU)上的语义分割,超分辨率,分割任何东西和ImageNet分类方面都比以前的SOTA模型有显著的加速。

三、方法论

3.1 Multi-Scale Linear Attention(多尺度线性注意力) 

多尺度线性注意力仅通过硬件高效的操作同时实现了全局感受野和多尺度学习。基于多尺度线性注意力,作者提出了一种新的用于高分辨率密集预测的Vision transformer模型EfficientVit。  

动机:从性能角度来看,全局感受野和多尺度学习是必不可少的。以前的 SOTA 高分辨率密集预测模型通过启用这些特征提供了较强的性能,但不能提供良好的效率。多尺度线性注意力模块通过用轻微的性能损失换取显著的效率提升来解决这个问题。

方法使用ReLU线性注意力来实现全局感受野,而不是繁重的softmax注意力。

ReLU线性注意力的公式推导

由传统的softmax注意力公式和Relu注意力相似度计算函数(相似度计算函数替换为Relu版的),可得:

由矩阵乘法的结合律,可得:

推导最终结论:由公式(3)所示,只需要计算\in \mathbb{R}^{d\times1}一次,就可以对每个Query重用它们(多头attention机制查询无关问题的最终解???),从而只需要O(N)的计算代价和O(N)的内存。 

  

ReLU线性注意力的局限性:如下图所示,softmax 注意和 ReLU 线性注意的注意图。由于缺乏非线性相似函数,ReLU 线性注意不能生成集中的注意图,捕获局部信息的能力较弱。(ReLU线性注意力缺点暴露)

解决方案:

1. 为了减轻其局限性,我们提出用卷积增强 ReLU 线性注意力。具体来说,在每个 FFN 层中插入深度卷积。如下图所示,其中 ReLU 线性注意力捕获上下文信息,FFN+DWConv 捕获局部信息

2. 将邻近的 Q/K/V token信息聚合(拼接)成多尺度token以增强 ReLU 线性注意的多尺度学习能力这里多尺度是指通道方向上的不同尺度,所以聚合能多尺度学习能力)。

具体来说,将所有DWConv融合成单个DWConv组,将所有 1x1 Convs合并成单个1x1的卷积组,组数为3 × #head,每组通道数为d。得到多尺度token后,对其进行ReLU线性注意力,提取多尺度全局特征。最后,将特征沿头部维度进行连接,并将其提供给最终的线性层以融合特征。

(本质上是使用nn.Conv2d()函数中的groups参数,将输入和输出通道分成几组进行卷积操作,学习通道方向上的不同尺度的信息。)

Q:感受野和注意力机制有什么关系?

A:注意力机制可以通过计算不同位置之间的关系,来捕捉长距离依赖关系,从而扩大感受野,提高网络的感知能力。

代码如下

轻量权重多尺度注意力模块

# 轻量权重多尺度注意力
class LiteMLA(nn.Module):
    r"""Lightweight multi-scale linear attention"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads: int or None = None,
        heads_ratio: float = 1.0,
        dim=8,
        use_bias=False,
        norm=(None, "bn2d"),
        act_func=(None, None),
        kernel_func="relu",
        scales: tuple[int, ...] = (5,),
        eps=1.0e-15,
    ):
        super(LiteMLA, self).__init__()
        self.eps = eps
        heads = heads or int(in_channels // dim * heads_ratio)

        total_dim = heads * dim

        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)

        self.dim = dim
        self.qkv = ConvLayer(
            in_channels,
            3 * total_dim,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act_func=act_func[0],
        )
        self.aggreg = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        3 * total_dim,
                        3 * total_dim,
                        scale,
                        padding=get_same_padding(scale),
                        groups=3 * total_dim,
                        bias=use_bias[0],
                    ),
                    nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
                )
                for scale in scales
            ]
        )              # nn.Conv2d()函数中的groups参数是指将输入和输出通道分成几组进行卷积操作
        self.kernel_func = build_act(kernel_func, inplace=False)    # Relu激活函数

        self.proj = ConvLayer(
            total_dim * (1 + len(scales)),
            out_channels,
            1,
            use_bias=use_bias[1],
            norm=norm[1],
            act_func=act_func[1],
        )

    @autocast(enabled=False)
    def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
        B, _, H, W = list(qkv.size())

        if qkv.dtype == torch.float16:
            qkv = qkv.float()

        qkv = torch.reshape(
            qkv,
            (
                B,
                -1,
                3 * self.dim,
                H * W,
            ),
        )
        qkv = torch.transpose(qkv, -1, -2)
        q, k, v = (
            qkv[..., 0 : self.dim],
            qkv[..., self.dim : 2 * self.dim],
            qkv[..., 2 * self.dim :],
        )

        # lightweight linear attention
        q = self.kernel_func(q)     # 进行relu激活
        k = self.kernel_func(k)     # 进行relu激活

        # linear matmul
        trans_k = k.transpose(-1, -2)

        v = F.pad(v, (0, 1), mode="constant", value=1)      # 进行维度扩展
        kv = torch.matmul(trans_k, v)       # 按推导公式计算
        out = torch.matmul(q, kv)
        out = out[..., :-1] / (out[..., -1:] + self.eps)

        out = torch.transpose(out, -1, -2)
        out = torch.reshape(out, (B, -1, H, W))
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # generate multi-scale q, k, v
        qkv = self.qkv(x)               # 获取Q、K、V,由1x1卷积得到
        multi_scale_qkv = [qkv]
        for op in self.aggreg:          # 卷积聚合,学习通道上的多尺度信息
            multi_scale_qkv.append(op(qkv))
        multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)     # Q、K、V拼接

        out = self.relu_linear_att(multi_scale_qkv)     # 重新等分划分为Q,K,V,馈入ReLU线性注意力
        out = self.proj(out)        # 1x1卷积输出,模拟线性层

        return out

3.2 EfficientViT架构

如上图所示,

Backbone(骨干):由输入层和四个阶段组成,特征图大小逐渐减小,通道数量逐渐增加。在阶段3和4中插入EfficientViT模块。对于下采样,我们使用步幅为2的MBConv。

Head(分割头):P2、P3和P4表示阶段2、3和4的输出,形成特征图的金字塔。为了简单和高效,使用1x 1卷积和标准上采样操作(例如,双线性/双三次上采样)以匹配它们的空间和信道大小并经由加法来融合它们。简单的头部设计,其包括若干MBConv块和输出层(即,预测和上采样)。

  

代码如下

Backbone(骨干)

class EfficientViTBackbone(nn.Module):
    # Backbone:input_stem + stage1 + stage2 + stage3 + stage4
    def __init__(
            self,
            width_list: list[int],
            depth_list: list[int],
            in_channels=3,
            dim=32,
            expand_ratio=4,
            norm="bn2d",
            act_func="hswish",
    ) -> None:
        super().__init__()

        self.width_list = []
        # input stem
        self.input_stem = [
            ConvLayer(
                in_channels=3,
                out_channels=width_list[0],
                stride=2,
                norm=norm,
                act_func=act_func,
            )   # 3x3卷积 -> 下采2倍
        ]
        for _ in range(depth_list[0]):
            block = self.build_local_block(         # 构建DSConv模块,捕捉局部信息
                in_channels=width_list[0],
                out_channels=width_list[0],
                stride=1,
                expand_ratio=1,
                norm=norm,
                act_func=act_func,
            )
            self.input_stem.append(ResidualBlock(block, IdentityLayer()))       # 增加残差
        in_channels = width_list[0]
        self.input_stem = OpSequential(self.input_stem)         # 把input_stem阶段各模块按顺序添加到ModuleList中
        self.width_list.append(in_channels)         # 把每个模块的通道数添加到width_list

        # stages
        self.stages = []

        # # # stages1
        for w, d in zip(width_list[1:3], depth_list[1:3]):
            stage = []
            for i in range(d):
                stride = 2 if i == 0 else 1
                block = self.build_local_block(     # 构建MBConv模块,捕捉局部信息
                    in_channels=in_channels,
                    out_channels=w,
                    stride=stride,
                    expand_ratio=expand_ratio,
                    norm=norm,
                    act_func=act_func,
                )
                block = ResidualBlock(block, IdentityLayer() if stride == 1 else None)      # 增加残差
                stage.append(block)
                in_channels = w
            self.stages.append(OpSequential(stage))
            self.width_list.append(in_channels)

        for w, d in zip(width_list[3:], depth_list[3:]):
            stage = []

            # # # stages2
            block = self.build_local_block(     # 构建MBConv模块,捕捉局部信息
                in_channels=in_channels,
                out_channels=w,
                stride=2,
                expand_ratio=expand_ratio,
                norm=norm,
                act_func=act_func,
                fewer_norm=True,
            )
            stage.append(ResidualBlock(block, None))
            in_channels = w

            # # # stages3、4
            for _ in range(d):
                stage.append(
                    EfficientViTBlock(      # EfficientViTBlock模块,多尺度注意力提取上下文特征
                        in_channels=in_channels,
                        dim=dim,
                        expand_ratio=expand_ratio,
                        norm=norm,
                        act_func=act_func,
                    )
                )
            self.stages.append(OpSequential(stage))
            self.width_list.append(in_channels)
        self.stages = nn.ModuleList(self.stages)    # nn.ModuleList,用于存储不同的模块,并自动将每个模块的参数添加到网络中

    # 构建DSConv 或 MBConv —> 局部信息
    @staticmethod
    def build_local_block(
            in_channels: int,
            out_channels: int,
            stride: int,
            expand_ratio: float,
            norm: str,
            act_func: str,
            fewer_norm: bool = False,
    ) -> nn.Module:
        if expand_ratio == 1:
            block = DSConv(                 # DSConv模块
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        else:
            block = MBConv(                 # MBConv模块,Mobile倒置残差瓶颈卷积 -> 2倍下采样
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False) if fewer_norm else False,
                norm=(None, None, norm) if fewer_norm else norm,
                act_func=(act_func, act_func, None),
            )
        return block

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        output_dict = {"input": x}
        output_dict["stage0"] = x = self.input_stem(x)
        for stage_id, stage in enumerate(self.stages, 1):   # 网络的backbone
            output_dict["stage%d" % stage_id] = x = stage(x)
        output_dict["stage_final"] = x
        return output_dict

DSConv模块

class DSConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        use_bias=False,
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super(DSConv, self).__init__()

        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)

        self.depth_conv = ConvLayer(
            in_channels,
            in_channels,
            kernel_size,
            stride,
            groups=in_channels,
            norm=norm[0],
            act_func=act_func[0],
            use_bias=use_bias[0],
        )
        self.point_conv = ConvLayer(
            in_channels,
            out_channels,
            1,
            norm=norm[1],
            act_func=act_func[1],
            use_bias=use_bias[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.depth_conv(x)
        x = self.point_conv(x)
        return x

MBConv模块

# MBConv
class MBConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        use_bias=False,
        norm=("bn2d", "bn2d", "bn2d"),
        act_func=("relu6", "relu6", None),
    ):
        super(MBConv, self).__init__()

        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act_func = val2tuple(act_func, 3)
        mid_channels = mid_channels or round(in_channels * expand_ratio)

        self.inverted_conv = ConvLayer(
            in_channels,
            mid_channels,
            1,
            stride=1,
            norm=norm[0],
            act_func=act_func[0],
            use_bias=use_bias[0],
        )
        self.depth_conv = ConvLayer(
            mid_channels,
            mid_channels,
            kernel_size,
            stride=stride,
            groups=mid_channels,
            norm=norm[1],
            act_func=act_func[1],
            use_bias=use_bias[1],
        )
        self.point_conv = ConvLayer(
            mid_channels,
            out_channels,
            1,
            norm=norm[2],
            act_func=act_func[2],
            use_bias=use_bias[2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.inverted_conv(x)       # 512
        x = self.depth_conv(x)          # 512
        x = self.point_conv(x)          # 256
        return x

 EfficientViTBlock模块

# EfficientViTBlock模块 —> 提取上下文特征
class EfficientViTBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        heads_ratio: float = 1.0,
        dim=32,
        expand_ratio: float = 4,
        norm="bn2d",
        act_func="hswish",
    ):
        super(EfficientViTBlock, self).__init__()
        self.context_module = ResidualBlock(
            LiteMLA(                            # 轻量权重多尺度注意力
                in_channels=in_channels,
                out_channels=in_channels,
                heads_ratio=heads_ratio,
                dim=dim,
                norm=(None, norm),
            ),
            IdentityLayer(),
        )
        local_module = MBConv(
            in_channels=in_channels,
            out_channels=in_channels,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False),
            norm=(None, None, norm),
            act_func=(act_func, act_func, None),
        )
        self.local_module = ResidualBlock(local_module, IdentityLayer())       # 添加残差连接

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.context_module(x)      # 轻量多尺度注意力 -> 全局上下文特征
        x = self.local_module(x)        # 深度卷积 -> 局部特征
        return x

  

四、实验

数据集:Cityscapes 和 ADE20K数据集。

评价指标:mIoU、Params和MAC(乘加累积操作数)。

4.1 消融研究

(1)EfficientViT模块的性能测试

mIoU和MAC在Cityscapes上测量,输入分辨率为1024x2048。重新调整模型的宽度,使它们具有相同的MAC,由上表所示,多尺度学习和全局感受野对于获得良好的语义分割性能至关重要。

(2)ImageNet上的主干性能对比

EfficientViT-L2-r384在ImageNet上获得了86.0的top-1精度,比EfficientNetV 2-L提供了+0.3的精度增益,在A100 GPU上提供了2.6倍的加速。

4.2 语义分割实验

与先进语义分割模型在Cityscapes数据集上的对比。

与SegFormer相比,EfficientViT在mIoU更高的边缘GPU(Jetson AGX Orin)上获得了高达13倍的MAC数节省和高达8.8倍的延迟减少。与SegNeXt相比,EfficientViT在边缘GPU上提供高达2.0倍的MAC减少和3.8倍的加速,同时保持更高的mIoU。 

五、总结

1. 本文针对高分辨率稠密预测的有效架构设计,引入了一个轻量级的多尺度注意力模块,它同时实现了全局感受野,以及具有轻量级和硬件高效操作的多尺度学习,从而在各种硬件设备上提供了显着的加速,而不会比SOTA高分辨率密集预测模型带来性能损失。

2. 多尺度线性注意力,使用ReLU线性注意力来实现全局感受野,通过FFN+DWConv 捕获局部信息和卷积聚合捕获多尺度信息,以此克服ReLU线性注意力轻量化所带来的缺点。

### EfficientViT 深度学习模型架构特点 EfficientViT作为一种基于Transformer架构的轻量级视觉模型,具备强大的特征提取能力和良好的泛化性能[^1]。其核心在于多尺度线性注意力机制的设计,这使得在高分辨率密集预测任务中能够同时达到高性能与高效的资源利用率。 #### 多尺度线性注意力模块 为了应对高分辨率输入带来的挑战,EfficientViT引入了多尺度线性注意力模块。这一设计不仅降低了传统自注意机制下的计算复杂度,还允许更大范围内的像素间关联建模,进而提高了对于细节捕捉的能力[^2]。 #### 特征融合网络(FFN)增强 除了改进注意力机制外,EfficientViT在网络内部增加了更多前馈神经网络(Feed-Forward Network, FFN)层的数量,并将其放置于单一自注意力层前后位置。这样的安排有效地减少了由于过多依赖自注意力而导致的时间开销问题;与此同时,通过加深FFNs层数促进了不同特征通道间的交流互动[^3]。 #### 局部结构信息保留 值得注意的是,在每一个FFN单元之前加入了特殊的令牌交互层——即深度可分离卷积(DWConv),以此方式向整个体系注入有关局部几何特性的先验假设。此改动有助于加强最终输出结果的空间一致性以及语义连贯性。 --- ### 应用场景实例分析 考虑到上述特性,EfficientViT非常适合应用于那些需要兼顾精度和速度的任务领域: - **物体检测**:借助出色的表征学习能力快速定位并分类图像中的目标对象; - **分割任务**:凭借优秀的上下文理解力精确划分各个区域边界; - **超分辨率重建(SR)**:虽然具体提到的是另一种称为SRFormer的工作专门针对此类应用场景做了优化调整[^4],但是鉴于两者都属于Vision Transformer家族成员之一的事实,可以推测EfficientViT同样适用于解决类似的计算机视觉难题。 ```python import torch from efficientvit import build_efficient_vit_model # 假设这是官方库的一部分 model = build_efficient_vit_model(pretrained=True) # 加载预训练权重 (如果有的话) if pretrained_weights_path is not None: model.load_state_dict(torch.load(pretrained_weights_path)) input_tensor = ... # 准备好待推理的数据张量 output = model(input_tensor) ```
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

向岸看

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值