(即插即用模块-Attention部分) 二十六、(ICCV 2023) MSLA 多尺度线性注意力

在这里插入图片描述

paper:EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction

Code:https://github.com/mit-han-lab/efficientvit


1、Multi-Scale Linear Attention

现有模型的局限性存在以下短处:计算成本高: 现有的高分辨率密集预测模型往往依赖于复杂的模型结构,例如 softmax 注意力机制、大卷积核等,这会导致计算成本高昂,难以在硬件设备上部署。性能提升有限: 一些轻量级的模型虽然计算成本较低,但性能提升有限,难以满足实际应用的需求。为了解决现有高分辨率密集预测模型在效率和性能之间的权衡问题。这篇论文提出一种新的多尺度线性注意力(Multi-Scale Linear Attention)。与以往的高分辨率稠密预测模型依赖于繁重的softmax注意力、硬件效率低的大核卷积或复杂的拓扑结构来获得良好性能不同,多尺度线性注意力仅需轻量级和硬件效率高的操作即可实现全局感受野和多尺度学习。

MSLA 的核心思想:全局感受野: 通过 ReLU 线性注意力机制,MSLA能够有效地聚合来自全局的信息,从而获得全局感受野,这对于高分辨率密集预测任务至关重要。多层次学习: 通过对 Q/K/V 向量进行小卷积核的聚合,MSLA 模块能够生成多尺度向量,从而实现多层次学习,捕获不同尺度的特征信息。

对于一个输入 X,MSLA 的实现原理:

  1. 输入特征图投影:输入特征图经过一个线性投影层,分别投影到 Q (Query), K (Key), V (Value) 向量。

  2. 多尺度特征生成:对 Q/K/V 向量进行分组操作,每组包含多个向量。对每个分组中的向量进行小卷积核的深度可分离卷积 (DWConv),生成多尺度特征图。

  3. ReLU 线性注意力:对多尺度向量进行 ReLU 线性注意力,提取全局特征信息。与 softmax 注意力机制相比,ReLU 线性注意力机制的计算复杂度更低,且不需要进行 softmax 操作,从而提高了硬件效率。

  4. 特征融合:最后将注意力机制输出与原始特征图进行融合,得到最终的输出特征图。输出特征图包含了不同尺度的全局特征信息,可以用于后续的任务,例如语义分割、超分辨率等。


Multi-Scale Linear Attention 结构图:
在这里插入图片描述

2、EfficientViT

基于 MSLA,论文提出一种新架构 EfficientViT,EfficientViT 是一种高效的高分辨率视觉模型,用于密集预测任务。EfficientViT 遵循标准的骨干网络设计,包含输入 Stem 和四个阶段:

  1. Input Stem:最开始的 Input Stem 由 卷积+DSConv组成。
  2. Stage:一般的 stage 位于第一, 二阶段,是由 MBConv 组成。
  3. EfficientViT Module:EfficientViT Module 被插入到第三,四阶段,是 EfficientViT 的基本构建块,由 MSLA 模块和一个 FFN+DWConv 层组成。MSLA 是 EfficientViT 的核心模块,负责提取全局特征信息和多层次特征信息。而 FFN+DWConv 层负责捕获局部特征信息。
  4. Head:最后,Head 用于处理 Backbone 的输出特征图,并生成最终的预测结果。Head 由多个 MBConv 块和输出层组成,例如预测层和上采样层。

EfficientViT 结构图:
在这里插入图片描述

3、代码实现

import math
import torch
import torch.nn as nn
from typing import Tuple
from functools import partial
from inspect import signature
import torch.nn.functional as F


def build_kwargs_from_config(config, target_func):
    valid_keys = list(signature(target_func).parameters)
    kwargs = {}
    for key in config:
        if key in valid_keys:
            kwargs[key] = config[key]
    return kwargs


class LayerNorm2d(nn.LayerNorm):
    def forward(self, x):
        out = x - torch.mean(x, dim=1, keepdim=True)
        out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
        if self.elementwise_affine:
            out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
        return out


REGISTERED_ACT_DICT = {
    "relu": nn.ReLU,
    "relu6": nn.ReLU6,
    "hswish": nn.Hardswish,
    "silu": nn.SiLU,
    "gelu": partial(nn.GELU, approximate="tanh"),
}


REGISTERED_NORM_DICT = {
    "bn2d": nn.BatchNorm2d,
    "ln": nn.LayerNorm,
    "ln2d": LayerNorm2d,
}


def build_act(name, **kwargs):
    if name in REGISTERED_ACT_DICT:
        act_cls = REGISTERED_ACT_DICT[name]
        args = build_kwargs_from_config(kwargs, act_cls)
        return act_cls(**args)
    else:
        return None


def build_act(name, **kwargs):
    if name in REGISTERED_ACT_DICT:
        act_cls = REGISTERED_ACT_DICT[name]
        args = build_kwargs_from_config(kwargs, act_cls)
        return act_cls(**args)
    else:
        return None


def get_same_padding(kernel_size):
    if isinstance(kernel_size, tuple):
        return tuple([get_same_padding(ks) for ks in kernel_size])
    else:
        assert kernel_size % 2 > 0, "kernel size should be odd number"
    return kernel_size // 2


def val2tuple(x, min_len=1, idx_repeat=0):
    x = val2list(x)

    # repeat elements if necessary
    if len(x) > 0:
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]

    return tuple(x)


def val2list(x, repeat_time=1):
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x for _ in range(repeat_time)]


def build_norm(name="bn2d", num_features=None, **kwargs):
    if name in ["ln", "ln2d", "trms2d"]:
        kwargs["normalized_shape"] = num_features
    else:
        kwargs["num_features"] = num_features
    if name in REGISTERED_NORM_DICT:
        norm_cls = REGISTERED_NORM_DICT[name]
        args = build_kwargs_from_config(kwargs, norm_cls)
        return norm_cls(**args)
    else:
        return None


class ConvLayer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        use_bias=False,
        dropout=0,
        norm="bn2d",
        act_func="relu",
    ):
        super(ConvLayer, self).__init__()

        padding = get_same_padding(kernel_size)
        padding *= dilation

        self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, kernel_size),
            stride=(stride, stride),
            padding=padding,
            dilation=(dilation, dilation),
            groups=groups,
            bias=use_bias,
        )
        self.norm = build_norm(norm, num_features=out_channels)
        self.act = build_act(act_func)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.conv(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x


class LiteMLA(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads_ratio: float = 1.0,
        dim=8,
        use_bias=False,
        norm=(None, "bn2d"),
        act_func=(None, None),
        kernel_func="relu",
        scales: Tuple[int, ...] = (5,),
        eps=1.0e-15,
    ):
        super(LiteMLA, self).__init__()
        self.eps = eps
        heads = int(in_channels // dim * heads_ratio)

        total_dim = heads * dim

        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)

        self.dim = dim
        self.qkv = ConvLayer(
            in_channels,
            3 * total_dim,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act_func=act_func[0],
        )
        self.aggreg = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        3 * total_dim,
                        3 * total_dim,
                        scale,
                        padding=get_same_padding(scale),
                        groups=3 * total_dim,
                        bias=use_bias[0],
                    ),
                    nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
                )
                for scale in scales
            ]
        )
        self.kernel_func = build_act(kernel_func, inplace=False)

        self.proj = ConvLayer(
            total_dim * (1 + len(scales)),
            out_channels,
            1,
            use_bias=use_bias[1],
            norm=norm[1],
            act_func=act_func[1],
        )

    @torch.autocast(device_type="cuda", enabled=False)
    def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
        B, _, H, W = list(qkv.size())

        if qkv.dtype == torch.float16:
            qkv = qkv.float()

        qkv = torch.reshape(
            qkv,
            (
                B,
                -1,
                3 * self.dim,
                H * W,
            ),
        )
        q, k, v = (
            qkv[:, :, 0 : self.dim],
            qkv[:, :, self.dim : 2 * self.dim],
            qkv[:, :, 2 * self.dim :],
        )

        # lightweight linear attention
        q = self.kernel_func(q)
        k = self.kernel_func(k)

        # linear matmul
        trans_k = k.transpose(-1, -2)

        v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
        vk = torch.matmul(v, trans_k)
        out = torch.matmul(vk, q)
        if out.dtype == torch.bfloat16:
            out = out.float()
        out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)

        out = torch.reshape(out, (B, -1, H, W))
        return out

    @torch.autocast(device_type="cuda", enabled=False)
    def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:
        B, _, H, W = list(qkv.size())

        qkv = torch.reshape(
            qkv,
            (
                B,
                -1,
                3 * self.dim,
                H * W,
            ),
        )
        q, k, v = (
            qkv[:, :, 0 : self.dim],
            qkv[:, :, self.dim : 2 * self.dim],
            qkv[:, :, 2 * self.dim :],
        )

        q = self.kernel_func(q)
        k = self.kernel_func(k)

        att_map = torch.matmul(k.transpose(-1, -2), q)  # b h n n
        original_dtype = att_map.dtype
        if original_dtype in [torch.float16, torch.bfloat16]:
            att_map = att_map.float()
        att_map = att_map / (torch.sum(att_map, dim=2, keepdim=True) + self.eps)  # b h n n
        att_map = att_map.to(original_dtype)
        out = torch.matmul(v, att_map)  # b h d n

        out = torch.reshape(out, (B, -1, H, W))
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # generate multi-scale q, k, v
        qkv = self.qkv(x)
        multi_scale_qkv = [qkv]
        for op in self.aggreg:
            multi_scale_qkv.append(op(qkv))
        qkv = torch.cat(multi_scale_qkv, dim=1)

        H, W = list(qkv.size())[-2:]
        if H * W > self.dim:
            out = self.relu_linear_att(qkv).to(qkv.dtype)
        else:
            out = self.relu_quadratic_att(qkv)
        out = self.proj(out)

        return out


if __name__ == '__main__':
    x = torch.randn(4, 512, 7, 7)
    model = LiteMLA(512,512)
    output = model(x)
    print(output.shape)

本文只是对论文中的即插即用模块做了整合,对论文中的一些地方难免有遗漏之处,如果想对这些模块有更详细的了解,还是要去读一下原论文,肯定会有更多收获。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值