paper: MViTv2: Improved Multiscale Vision Transformers for Classification and Detection
official implementation: https://github.com/facebookresearch/SlowFast
third-party implementation: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/mvitv2.py
Motivation
Designing architectures for visual recognition has long been challenging and usually involves a trade-off between simplicity and effectiveness. In recent years, Vision Transformers (ViT) have shown strong performance on image classification, but they still face heavy compute and memory demands on high-resolution object detection and spatio-temporal video understanding. The starting point of this paper is to improve the MViT architecture so that it handles these more demanding vision tasks better.
Contributions
This paper presents an improved Multiscale Vision Transformer (MViTv2), which uses decomposed relative position embeddings and residual pooling connections to improve performance on image classification, object detection, and video classification. MViTv2 performs strongly on ImageNet classification, COCO detection, and Kinetics video recognition, surpassing prior work.
- Decomposed relative position embeddings: relative position embeddings that depend only on the relative distance between tokens remove MViT's reliance on "absolute" position embeddings for modeling spatio-temporal structure, improving shift invariance.
- Residual pooling connections: a residual pooling connection is added to the Transformer block to compensate for the effect of the pooling stride in the attention computation.
- Hybrid window attention: combining pooling attention with window attention improves both computational efficiency and accuracy.
Method
Decomposed relative position embedding
Although MViT can model interactions between tokens, it attends to content rather than structure: modeling of the spatio-temporal structure relies entirely on the "absolute" position embedding for location information. This ignores the shift-invariance principle in vision, i.e. the way MViT models two tokens changes with their absolute positions even when their relative positions stay the same. To address this, the paper introduces relative position embeddings.
The relative position between two input elements \(i\) and \(j\) is encoded as a positional embedding \(R_{p(i),p(j)}\in \mathbb{R}^d\), where \(p(i)\) and \(p(j)\) denote the spatial positions of \(i\) and \(j\). The pairwise encodings are then added into the attention computation:
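\[
\mathrm{Attn}(Q,K,V)=\mathrm{Softmax}\left(\frac{QK^{\top}+E^{(\mathrm{rel})}}{\sqrt{d}}\right)V,\qquad E^{(\mathrm{rel})}_{ij}=Q_i\cdot R_{p(i),p(j)}
\]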
However, the number of possible \(R_{p(i),p(j)}\) embeddings is on the order of \(\mathcal{O}(TWH)\), which is expensive to learn and compute. To reduce the complexity, the authors decompose the distance computation between \(i\) and \(j\) along the spatio-temporal axes:
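\[
R_{p(i),p(j)}=R^{h}_{h(i),h(j)}+R^{w}_{w(i),w(j)}+R^{t}_{t(i),t(j)}
\]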
where \(R^h\), \(R^w\), \(R^t\) are the position embeddings along the height, width, and time axes, and \(h(i)\), \(w(i)\), \(t(i)\) denote the vertical, horizontal, and temporal position of token \(i\). \(R^t\) is optional and only used for video tasks. The decomposed embeddings reduce the number of learned embeddings to \(\mathcal{O}(T+W+H)\).
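As a concrete example from the code walkthrough below: in the first stage a \(56\times 56\) query grid attends to a \(14\times 14\) key grid, so each spatial axis needs only \(2\cdot\max(56,14)-1=111\) learned vectors, matching the \((111, 96)\) shape of rel_pos_h and rel_pos_w in the timm implementation, instead of one embedding per query-key pair.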
Residual pooling connection
In MViTv1 the stride on \(K\) and \(V\) is larger than that on \(Q\), since the \(Q\) tensor is downsampled only when the resolution of the output sequence changes. This motivates adding a residual pooling connection with the pooled \(Q\) tensor to increase information flow and ease the training of the pooling attention blocks in MViT.
As shown in Figure 2 of the paper, a residual pooling connection is added inside the attention block; concretely, the pooled query is added to the output sequence \(Z\):
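\[
Z := \mathrm{Attn}(Q, K, V) + Q
\]
where \(Q\), \(K\), \(V\) denote the pooled query, key, and value tensors.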
Experiments
The configurations of the different model variants are as follows.
Results on ImageNet-1K are shown in the table below; MViTv2 achieves better results than MViTv1.
Because it uses the same pyramid structure as convolutional networks, MViTv2 can naturally serve as the backbone of object detection models, e.g. combined with an FPN as the paper illustrates.
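As a rough, self-contained sketch of the same idea (this is not the paper's detection code), the snippet below feeds stand-in tensors with the channel counts and strides of the four mvitv2_small stages (96/192/384/768 channels at strides 4/8/16/32 for a 224x224 input) into torchvision's FeaturePyramidNetwork:

import torch
from collections import OrderedDict
from torchvision.ops import FeaturePyramidNetwork

# Stand-ins for the four stage outputs of mvitv2_small, reshaped to NCHW.
feats = OrderedDict(
    stage1=torch.randn(1, 96, 56, 56),
    stage2=torch.randn(1, 192, 28, 28),
    stage3=torch.randn(1, 384, 14, 14),
    stage4=torch.randn(1, 768, 7, 7),
)
fpn = FeaturePyramidNetwork(in_channels_list=[96, 192, 384, 768], out_channels=256)
outs = fpn(feats)
for name, feat in outs.items():
    print(name, feat.shape)  # every pyramid level is projected to 256 channels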
Results on the COCO dataset with Mask R-CNN and Cascade Mask R-CNN are shown below.
Code walkthrough
Here we take the timm implementation as an example, with model "mvitv2_small" and an input of size (1, 3, 224, 224). A minimal way to instantiate the model is sketched below; after that comes the stage-building loop from the model's __init__, with the configuration values of this variant annotated as comments.
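The following snippet (assuming a timm version that registers mvitv2_small) simply creates the model and runs a dummy forward pass; the printed shape corresponds to the default 1000-class head:

import torch
import timm

# Build the model analyzed in this walkthrough and push a dummy input through it.
model = timm.create_model('mvitv2_small', pretrained=False)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 1000])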
for i in range(num_stages):
if cfg.expand_attn: # True
dim_out = cfg.embed_dim[i] # (96, 192, 384, 768)
else:
dim_out = cfg.embed_dim[min(i + 1, num_stages - 1)]
stage = MultiScaleVitStage(
dim=embed_dim, # 96
dim_out=dim_out,
depth=cfg.depths[i], # (1, 2, 11, 2)
num_heads=cfg.num_heads[i], # (1, 2, 4, 8)
feat_size=feat_size, # (56,56)
mlp_ratio=cfg.mlp_ratio, # 4.0
qkv_bias=cfg.qkv_bias, # True
mode=cfg.mode, # conv
pool_first=cfg.pool_first, # False
expand_attn=cfg.expand_attn, # True
kernel_q=cfg.kernel_qkv, # (3,3)
kernel_kv=cfg.kernel_qkv, # (3,3)
stride_q=cfg.stride_q[i], # ((1, 1), (2, 2), (2, 2), (2, 2))
stride_kv=cfg.stride_kv[i], # ((4, 4), (2, 2), (1, 1), (1, 1))
has_cls_token=cfg.use_cls_token, # False
rel_pos_type=cfg.rel_pos_type, # spatial
residual_pooling=cfg.residual_pooling, # True
norm_layer=norm_layer, # LayerNorm
drop_path=dpr[i], # [[0.0], [0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0]]
)
curr_stride *= max(cfg.stride_q[i])
self.feature_info += [dict(module=f'block.{i}', num_chs=dim_out, reduction=curr_stride)]
embed_dim = dim_out
feat_size = stage.feat_size
self.stages.append(stage)
The attention class is implemented as follows. After the stem, the first stage does not downsample the spatial resolution any further, so self.pool_q has stride 1, while self.pool_k and self.pool_v both use stride 4 purely to reduce the amount of computation; the K/V stride has no effect on the output resolution.
The residual pooling connection proposed in the paper corresponds to the self.residual_pooling branch near the end of forward(), where the pooled Q is added to the final attention output.
The other contribution, the decomposed relative position embedding, corresponds to the rel_pos_type == 'spatial' branch: the rel_pos_h / rel_pos_w parameters in __init__ and the call to cal_rel_pos_type in forward().
class MultiScaleAttention(nn.Module):
def __init__(
self,
dim,
dim_out,
feat_size,
num_heads=8,
qkv_bias=True,
mode="conv",
kernel_q=(1, 1),
kernel_kv=(1, 1),
stride_q=(1, 1), # (1,1)
stride_kv=(1, 1), # (4,4)
has_cls_token=True,
rel_pos_type='spatial',
residual_pooling=True,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.num_heads = num_heads
self.dim_out = dim_out
self.head_dim = dim_out // num_heads
self.scale = self.head_dim ** -0.5
self.has_cls_token = has_cls_token
padding_q = tuple([int(q // 2) for q in kernel_q])
padding_kv = tuple([int(kv // 2) for kv in kernel_kv])
self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
self.proj = nn.Linear(dim_out, dim_out)
# Skip pooling with kernel and stride size of (1, 1, 1).
if prod(kernel_q) == 1 and prod(stride_q) == 1:
kernel_q = None
if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
kernel_kv = None
self.mode = mode
self.unshared = mode == 'conv_unshared'
self.norm_q, self.norm_k, self.norm_v = None, None, None
self.pool_q, self.pool_k, self.pool_v = None, None, None
if mode in ("avg", "max"):
pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
if kernel_q:
self.pool_q = pool_op(kernel_q, stride_q, padding_q)
if kernel_kv:
self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
elif mode == "conv" or mode == "conv_unshared":
dim_conv = dim_out // num_heads if mode == "conv" else dim_out
if kernel_q:
self.pool_q = nn.Conv2d(
dim_conv,
dim_conv,
kernel_q,
stride=stride_q,
padding=padding_q,
groups=dim_conv,
bias=False,
)
self.norm_q = norm_layer(dim_conv)
if kernel_kv:
self.pool_k = nn.Conv2d(
dim_conv,
dim_conv,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=dim_conv,
bias=False,
)
self.norm_k = norm_layer(dim_conv)
self.pool_v = nn.Conv2d(
dim_conv,
dim_conv,
kernel_kv,
stride=stride_kv,
padding=padding_kv,
groups=dim_conv,
bias=False,
)
self.norm_v = norm_layer(dim_conv)
else:
raise NotImplementedError(f"Unsupported model {mode}")
# relative pos embedding
self.rel_pos_type = rel_pos_type
if self.rel_pos_type == 'spatial':
assert feat_size[0] == feat_size[1]
size = feat_size[0] # 56
q_size = size // stride_q[1] if len(stride_q) > 0 else size # 56 // 1 = 56
kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size # 56 // 4 = 14
rel_sp_dim = 2 * max(q_size, kv_size) - 1 # 111
self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim)) # (111,96)
self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
trunc_normal_tf_(self.rel_pos_h, std=0.02)
trunc_normal_tf_(self.rel_pos_w, std=0.02)
self.residual_pooling = residual_pooling
def forward(self, x, feat_size: List[int]):
# For the first stage of mvitv2_small the pooling layers are depthwise convolutions:
#   pool_q: Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96, bias=False)
#   pool_k: Conv2d(96, 96, kernel_size=(3, 3), stride=(4, 4), padding=(1, 1), groups=96, bias=False)
#   pool_v: Conv2d(96, 96, kernel_size=(3, 3), stride=(4, 4), padding=(1, 1), groups=96, bias=False)
B, N, _ = x.shape # (1,3136,96)
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # (1,3136,288)->(1,3136,3,1,96)->(3,1,1,3136,96)
q, k, v = qkv.unbind(dim=0) # (1,1,3136,96)
if self.pool_q is not None:
q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token) # (1,1,3136,96)->(1,96,56,56)
q = self.pool_q(q) # (1,96,56,56)
q, q_size = reshape_post_pool(q, self.num_heads, q_tok) # (1,1,3136,96)
else:
q_size = feat_size
if self.norm_q is not None:
q = self.norm_q(q)
if self.pool_k is not None:
# (1,1,3136,96)
k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token) # (1,96,56,56)
k = self.pool_k(k) # (1,96,14,14)
k, k_size = reshape_post_pool(k, self.num_heads, k_tok) # (1,1,196,96)
else:
k_size = feat_size
if self.norm_k is not None:
k = self.norm_k(k)
if self.pool_v is not None:
v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
v = self.pool_v(v)
v, _ = reshape_post_pool(v, self.num_heads, v_tok) # (1,1,196,96)
if self.norm_v is not None:
v = self.norm_v(v)
attn = (q * self.scale) @ k.transpose(-2, -1) # (1,1,3136,196)
if self.rel_pos_type == 'spatial':
attn = cal_rel_pos_type(
attn,
q,
self.has_cls_token,
q_size,
k_size,
self.rel_pos_h,
self.rel_pos_w,
)
attn = attn.softmax(dim=-1)
x = attn @ v
if self.residual_pooling:
x = x + q
x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
x = self.proj(x)
return x, q_size
The cal_rel_pos_type function is implemented as follows. It computes the relative position offsets between every point of the (56, 56) query grid and every point of the (14, 14) key grid, with the offsets decomposed and handled separately along the h and w directions.
def cal_rel_pos_type(
attn: torch.Tensor, # (1,1,3136,196)
q: torch.Tensor, # (1,1,3136,96)
has_cls_token: bool, # False
q_size: List[int], # [56,56]
k_size: List[int], # [14,14]
rel_pos_h: torch.Tensor, # (111,96)
rel_pos_w: torch.Tensor, # (111,96)
):
"""
Spatial Relative Positional Embeddings.
"""
sp_idx = 1 if has_cls_token else 0
q_h, q_w = q_size
k_h, k_w = k_size
# Scale up rel pos if shapes for q and k are different.
q_h_ratio = max(k_h / q_h, 1.0) # 1.0
k_h_ratio = max(q_h / k_h, 1.0) # 4.0
dist_h = (
torch.arange(q_h, device=q.device).unsqueeze(-1) * q_h_ratio -
torch.arange(k_h, device=q.device).unsqueeze(0) * k_h_ratio
) # (56,1) - (1,14) -> (56,14)
dist_h += (k_h - 1) * k_h_ratio # (14 - 1) * 4 = 52; this shift makes all dist_h values >= 0 so they can index into rel_pos_h
q_w_ratio = max(k_w / q_w, 1.0) # 1.0
k_w_ratio = max(q_w / k_w, 1.0) # 4.0
dist_w = (
torch.arange(q_w, device=q.device).unsqueeze(-1) * q_w_ratio -
torch.arange(k_w, device=q.device).unsqueeze(0) * k_w_ratio
) # (56,1)-(1,14) -> (56,14)
dist_w += (k_w - 1) * k_w_ratio
rel_h = rel_pos_h[dist_h.long()] # (56,14,96)
rel_w = rel_pos_w[dist_w.long()] # (56,14,96)
B, n_head, q_N, dim = q.shape # (1,1,3136,96)
r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim) # (1,1,56,56,96)
rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, rel_h) # (1,1,56,56,14)
rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, rel_w) # (1,1,56,56,14)
attn[:, :, sp_idx:, sp_idx:] = (
attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w) # (1,1,3136,196)->(1,1,56,56,14,14)
+ rel_h.unsqueeze(-1) # (1,1,56,56,14,1)
+ rel_w.unsqueeze(-2) # (1,1,56,56,1,14)
).view(B, -1, q_h * q_w, k_h * k_w) # (1,1,56,56,14,14)->(1,1,3136,196)
return attn
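To see why the shift by (k_h - 1) * k_h_ratio turns the signed distances into valid table indices, here is a tiny stand-alone check (not part of timm) that reproduces the index arithmetic for q_h = 4 and k_h = 2:

import torch

# Reproduce the height-axis index arithmetic of cal_rel_pos_type for a toy size.
q_h, k_h = 4, 2
q_h_ratio = max(k_h / q_h, 1.0)  # 1.0
k_h_ratio = max(q_h / k_h, 1.0)  # 2.0
dist_h = (
    torch.arange(q_h).unsqueeze(-1) * q_h_ratio -
    torch.arange(k_h).unsqueeze(0) * k_h_ratio
)
dist_h += (k_h - 1) * k_h_ratio  # shift the signed distances to be non-negative
print(dist_h.long())
# tensor([[2, 0],
#         [3, 1],
#         [4, 2],
#         [5, 3]])
# All indices fall within [0, 2 * max(q_h, k_h) - 2], i.e. inside a table of size rel_sp_dim = 7.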