Paper: EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
1、Multi-Scale Linear Attention
Limitations of existing models:
- High computational cost: existing high-resolution dense prediction models often rely on heavy components such as softmax attention and large-kernel convolutions, which drives up the computational cost and makes them difficult to deploy on hardware.
- Limited performance gains: some lightweight models reduce the cost, but their accuracy improvements are modest and fall short of practical requirements.
To resolve this trade-off between efficiency and performance in high-resolution dense prediction, the paper proposes a new Multi-Scale Linear Attention (MSLA). Unlike previous high-resolution dense prediction models that depend on heavy softmax attention, hardware-inefficient large-kernel convolutions, or complicated topologies to obtain good performance, multi-scale linear attention achieves a global receptive field and multi-scale learning with only lightweight, hardware-efficient operations.
The core ideas of MSLA:
- Global receptive field: through the ReLU linear attention mechanism, MSLA effectively aggregates information from the whole feature map and thus obtains a global receptive field, which is crucial for high-resolution dense prediction tasks.
- Multi-scale learning: by aggregating the Q/K/V tokens with small-kernel convolutions, the MSLA module generates multi-scale tokens, enabling multi-scale learning and capturing feature information at different scales.
For an input X, MSLA works as follows:
- Input projection: the input feature map is passed through a linear projection layer that produces the Q (query), K (key), and V (value) tokens.
- Multi-scale token generation: the Q/K/V tokens are split into groups, and each group is aggregated with small-kernel depthwise separable convolutions (DWConv) to generate multi-scale feature maps.
- ReLU linear attention: ReLU linear attention is applied to the multi-scale tokens to extract global feature information. Compared with softmax attention, ReLU linear attention has lower computational complexity and avoids the softmax operation, which improves hardware efficiency (a minimal sketch of this step follows the diagram below).
- Feature fusion: finally, the attention output is fused with the original feature map to produce the final output. The output feature map carries global feature information at multiple scales and can be used by downstream tasks such as semantic segmentation and super-resolution.
Multi-Scale Linear Attention architecture diagram:
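The efficiency of the ReLU linear attention step comes from reordering the matrix products: with a ReLU kernel in place of softmax, K^T V (a small head_dim x head_dim matrix per head) can be computed first, so the cost grows linearly rather than quadratically with the number of tokens N = H x W. Below is a minimal sketch of that idea, not the paper's full module; the (B, heads, head_dim, N) tensor layout and the function name are assumptions chosen to match the LiteMLA code in section 3:

import torch
import torch.nn.functional as F

def relu_linear_attention(q, k, v, eps=1.0e-15):
    # q, k, v: (B, heads, head_dim, N), where N = H * W spatial tokens
    q = F.relu(q)  # the ReLU kernel replaces the softmax similarity
    k = F.relu(k)
    # Compute V @ K^T first: a head_dim x head_dim matrix per head,
    # so the cost is O(N) in the token count instead of O(N^2)
    vk = torch.matmul(v, k.transpose(-1, -2))   # (B, heads, head_dim, head_dim)
    out = torch.matmul(vk, q)                   # (B, heads, head_dim, N)
    # Normalizer: each query dotted with the sum of all keys
    denom = torch.matmul(k.sum(dim=-1, keepdim=True).transpose(-1, -2), q)  # (B, heads, 1, N)
    return out / (denom + eps)

# toy usage
q, k, v = (torch.randn(2, 4, 8, 16 * 16) for _ in range(3))
print(relu_linear_attention(q, k, v).shape)  # torch.Size([2, 4, 8, 256])

The LiteMLA implementation in section 3 folds the normalizer into the same matmul by padding V with a row of ones, but the arithmetic is equivalent.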
2、EfficientViT
Building on MSLA, the paper proposes a new architecture, EfficientViT, an efficient high-resolution vision model for dense prediction tasks. EfficientViT follows a standard backbone design, consisting of an input stem and four stages:
- Input Stem: the initial input stem is composed of a convolution followed by DSConv blocks.
- Stage: the ordinary stages, located in stages 1 and 2, are built from MBConv blocks.
- EfficientViT Module: the EfficientViT module is inserted in stages 3 and 4 and is the basic building block of EfficientViT. It consists of an MSLA module and an FFN+DWConv layer: MSLA is the core module and extracts global and multi-scale feature information, while the FFN+DWConv layer captures local feature information (a structural sketch follows the diagram below).
- Head: finally, the head processes the backbone's output feature maps and generates the final prediction. It is composed of several MBConv blocks and output layers, such as a prediction layer and upsampling layers.
EfficientViT architecture diagram:
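To make the module description concrete, here is a rough schematic of one EfficientViT module: a global context branch (e.g. the LiteMLA module implemented in section 3) plus a local FFN+DWConv branch, each wrapped in a residual connection. The class name, expansion ratio, and activation choice are assumptions for illustration, not the paper's exact configuration:

import torch
import torch.nn as nn

class EfficientViTBlockSketch(nn.Module):
    def __init__(self, dim, context_module, expand=4):
        super().__init__()
        # Global branch: multi-scale linear attention (e.g. LiteMLA from section 3)
        self.context_module = context_module
        # Local branch: FFN with a depthwise conv in the middle
        hidden = dim * expand
        self.local_module = nn.Sequential(
            nn.Conv2d(dim, hidden, 1),
            nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden),  # DWConv
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, dim, 1),
        )

    def forward(self, x):
        x = x + self.context_module(x)  # global, multi-scale context
        x = x + self.local_module(x)    # local feature mixing
        return x

With the LiteMLA module from section 3, one block for a 512-channel stage could be built as EfficientViTBlockSketch(512, LiteMLA(512, 512)).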
3、Code Implementation
import math
import torch
import torch.nn as nn
from typing import Tuple
from functools import partial
from inspect import signature
import torch.nn.functional as F
def build_kwargs_from_config(config, target_func):
    # Keep only the config entries that the target callable actually accepts
    valid_keys = list(signature(target_func).parameters)
    kwargs = {}
    for key in config:
        if key in valid_keys:
            kwargs[key] = config[key]
    return kwargs
class LayerNorm2d(nn.LayerNorm):
    # LayerNorm over the channel dimension of a (B, C, H, W) tensor
    def forward(self, x):
        out = x - torch.mean(x, dim=1, keepdim=True)
        out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
        if self.elementwise_affine:
            out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
        return out
REGISTERED_ACT_DICT = {
    "relu": nn.ReLU,
    "relu6": nn.ReLU6,
    "hswish": nn.Hardswish,
    "silu": nn.SiLU,
    "gelu": partial(nn.GELU, approximate="tanh"),
}
REGISTERED_NORM_DICT = {
    "bn2d": nn.BatchNorm2d,
    "ln": nn.LayerNorm,
    "ln2d": LayerNorm2d,
}
def build_act(name, **kwargs):
    # Build an activation layer by name, passing only the kwargs it accepts
    if name in REGISTERED_ACT_DICT:
        act_cls = REGISTERED_ACT_DICT[name]
        args = build_kwargs_from_config(kwargs, act_cls)
        return act_cls(**args)
    else:
        return None
def get_same_padding(kernel_size):
    # "Same" padding for odd kernel sizes; supports per-dimension tuples
    if isinstance(kernel_size, tuple):
        return tuple([get_same_padding(ks) for ks in kernel_size])
    else:
        assert kernel_size % 2 > 0, "kernel size should be odd number"
        return kernel_size // 2
def val2tuple(x, min_len=1, idx_repeat=0):
    # Normalize a scalar/list/tuple to a tuple with at least min_len elements
    x = val2list(x)
    # repeat elements if necessary
    if len(x) > 0:
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
    return tuple(x)

def val2list(x, repeat_time=1):
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x for _ in range(repeat_time)]
def build_norm(name="bn2d", num_features=None, **kwargs):
    # LayerNorm variants expect normalized_shape, BatchNorm expects num_features
    if name in ["ln", "ln2d", "trms2d"]:
        kwargs["normalized_shape"] = num_features
    else:
        kwargs["num_features"] = num_features
    if name in REGISTERED_NORM_DICT:
        norm_cls = REGISTERED_NORM_DICT[name]
        args = build_kwargs_from_config(kwargs, norm_cls)
        return norm_cls(**args)
    else:
        return None
class ConvLayer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        use_bias=False,
        dropout=0,
        norm="bn2d",
        act_func="relu",
    ):
        super(ConvLayer, self).__init__()
        padding = get_same_padding(kernel_size)
        padding *= dilation
        self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, kernel_size),
            stride=(stride, stride),
            padding=padding,
            dilation=(dilation, dilation),
            groups=groups,
            bias=use_bias,
        )
        self.norm = build_norm(norm, num_features=out_channels)
        self.act = build_act(act_func)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.conv(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x
class LiteMLA(nn.Module):
    # Lightweight multi-scale linear attention: the MSLA module described above
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads_ratio: float = 1.0,
        dim=8,
        use_bias=False,
        norm=(None, "bn2d"),
        act_func=(None, None),
        kernel_func="relu",
        scales: Tuple[int, ...] = (5,),
        eps=1.0e-15,
    ):
        super(LiteMLA, self).__init__()
        self.eps = eps
        heads = int(in_channels // dim * heads_ratio)
        total_dim = heads * dim
        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)
        self.dim = dim
        # 1x1 projection that produces Q, K, V in a single tensor
        self.qkv = ConvLayer(
            in_channels,
            3 * total_dim,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act_func=act_func[0],
        )
        # Small-kernel depthwise + grouped pointwise convs aggregate Q/K/V into multi-scale tokens
        self.aggreg = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        3 * total_dim,
                        3 * total_dim,
                        scale,
                        padding=get_same_padding(scale),
                        groups=3 * total_dim,
                        bias=use_bias[0],
                    ),
                    nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
                )
                for scale in scales
            ]
        )
        self.kernel_func = build_act(kernel_func, inplace=False)
        # Output projection fuses the original-scale and multi-scale branches
        self.proj = ConvLayer(
            total_dim * (1 + len(scales)),
            out_channels,
            1,
            use_bias=use_bias[1],
            norm=norm[1],
            act_func=act_func[1],
        )
    @torch.autocast(device_type="cuda", enabled=False)
    def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
        B, _, H, W = list(qkv.size())
        if qkv.dtype == torch.float16:
            qkv = qkv.float()
        qkv = torch.reshape(
            qkv,
            (
                B,
                -1,
                3 * self.dim,
                H * W,
            ),
        )
        q, k, v = (
            qkv[:, :, 0 : self.dim],
            qkv[:, :, self.dim : 2 * self.dim],
            qkv[:, :, 2 * self.dim :],
        )
        # lightweight linear attention
        q = self.kernel_func(q)
        k = self.kernel_func(k)
        # linear matmul: compute (V K^T) first; padding V with a row of ones
        # yields the normalizer in the same product
        trans_k = k.transpose(-1, -2)
        v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
        vk = torch.matmul(v, trans_k)
        out = torch.matmul(vk, q)
        if out.dtype == torch.bfloat16:
            out = out.float()
        out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)
        out = torch.reshape(out, (B, -1, H, W))
        return out
    @torch.autocast(device_type="cuda", enabled=False)
    def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:
        B, _, H, W = list(qkv.size())
        qkv = torch.reshape(
            qkv,
            (
                B,
                -1,
                3 * self.dim,
                H * W,
            ),
        )
        q, k, v = (
            qkv[:, :, 0 : self.dim],
            qkv[:, :, self.dim : 2 * self.dim],
            qkv[:, :, 2 * self.dim :],
        )
        q = self.kernel_func(q)
        k = self.kernel_func(k)
        att_map = torch.matmul(k.transpose(-1, -2), q)  # b h n n
        original_dtype = att_map.dtype
        if original_dtype in [torch.float16, torch.bfloat16]:
            att_map = att_map.float()
        att_map = att_map / (torch.sum(att_map, dim=2, keepdim=True) + self.eps)  # b h n n
        att_map = att_map.to(original_dtype)
        out = torch.matmul(v, att_map)  # b h d n
        out = torch.reshape(out, (B, -1, H, W))
        return out
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # generate multi-scale q, k, v
        qkv = self.qkv(x)
        multi_scale_qkv = [qkv]
        for op in self.aggreg:
            multi_scale_qkv.append(op(qkv))
        qkv = torch.cat(multi_scale_qkv, dim=1)
        H, W = list(qkv.size())[-2:]
        # linear attention when there are many tokens, quadratic attention otherwise
        if H * W > self.dim:
            out = self.relu_linear_att(qkv).to(qkv.dtype)
        else:
            out = self.relu_quadratic_att(qkv)
        out = self.proj(out)
        return out
if __name__ == '__main__':
    x = torch.randn(4, 512, 7, 7)
    model = LiteMLA(512, 512)
    output = model(x)
    print(output.shape)  # torch.Size([4, 512, 7, 7])
This article only integrates the plug-and-play module from the paper, and some details are inevitably left out. For a more thorough understanding of these modules, reading the original paper is still recommended and will certainly be more rewarding.