Improved Multiscale Vision Transformers
MViT v2引入了两个设计来进一步提升MViT v1的性能
(1)、使用可分解的相对位置编码来注入位置信息
(2)、使用残差连接来弥补SA计算中步长的损失
MViT v1中在不同的stage建立不同分辨率的模块。数据的通道数逐渐增多,同时数据的分辨率(对应的序列长度)会逐渐下降。为了在Transformer模块中实现下采样操作,MViT提出了Pooling Attention。
对于任意输入序列,通过Linear得到Q,K,V矩阵。Q,K,V矩阵经过池化处理,在池化之后的基础上进行注意力的计算。其中对K、V矩阵进行池化操作的kernal,stride,padding保持一致,对Q矩阵,残差处理的池化操作的kernal,stride,padding保持一致。
Pooling Attention可以在每个stage都进行池化,这样可以大大减少Q-K-V计算时的内存成本和计算量。
MViTv1中引入的MSPA池化注意力大大减少SA的计算量,主要会在Q-K-V进行线性映射后在进行一步池化操作,但是在v1中K,V采用的步长更大,Q只有在输出序列发生变化时才进行降采样,这就需要在pooling attention module的计算中加入残差连接来增加信息流动。MViT v2在注意力模块中引入一种新的残差池化连接,表示为以下公式:
模型会在注意力计算后与pooled Q进行残差连接作为最终的输出。(Q和Z的shape应该相同)
消融实验表明,使用残差连接和对Q进行池化都是很有比较的,一方面可以降低SA的计算复杂福一方面可以提升性能
# 输入tensor查看数据处理过程 (batch_size,channel,H,W)
model = build_model(cfg)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.rand((10, 3, 224, 224))
x = x.to(device)
model(x)
#Pooling Attention
#此代码为上图部分的数据处理流程
def forward(self, x, hw_shape): # x:[10,3137,96] , hw_shape:[56,56]
B, N, _ = x.shape
if self.pool_first:
if self.mode == "conv_unshared":
fold_dim = 1
else:
fold_dim = self.num_heads
x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
q = k = v = x
else:
assert self.mode != "conv_unshared"
qkv = (
self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2] # q.shape = k.shape = v.shape = [10,1,3137,96] 1为num_heads
# attention_pool为执行pooling操作
# self.pool_k: Conv2d(96, 96, kernel_size=(3, 3), stride=(4, 4), padding=(1, 1), groups=96, bias=False)
# self.pool_q: Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96, bias=False)
# self.pool_v: Conv2d(96, 96, kernel_size=(3, 3), stride=(4, 4), padding=(1, 1), groups=96, bias=False)
# 代码中第一层block没有 增加/降低 图片的 channel/分辨率 。
q, q_shape = attention_pool( # q: [10,1,3137,96] q_shape: [56,56]
q,
self.pool_q,
hw_shape,
has_cls_embed=self.has_cls_embed,
norm=self.norm_q if hasattr(self, "norm_q") else None,
)
k, k_shape = attention_pool( # k: [10,1,197,96] k_shape: [14,14]
k,
self.pool_k,
hw_shape,
has_cls_embed=self.has_cls_embed,
norm=self.norm_k if hasattr(self, "norm_k") else None,
)
v, v_shape = attention_pool( # v: [10,1,197,96] v_shape: [14,14]
v,
self.pool_v,
hw_shape,
has_cls_embed=self.has_cls_embed,
norm=self.norm_v if hasattr(self, "norm_v") else None,
)
if self.pool_first:
q_N = numpy.prod(q_shape) + 1 if self.has_cls_embed else numpy.prod(q_shape)
k_N = numpy.prod(k_shape) + 1 if self.has_cls_embed else numpy.prod(k_shape)
v_N = numpy.prod(v_shape) + 1 if self.has_cls_embed else numpy.prod(v_shape)
q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
q = self.q(q).reshape(B, q_N, self.num_heads, -1).permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
v = self.v(v).reshape(B, v_N, self.num_heads, -1).permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 1, 3)
N = q.shape[2]
attn = (q * self.scale) @ k.transpose(-2, -1) # attn: [10,1,3137,197]
if self.rel_pos_spatial:
#添加相对位置编码(本篇论文的创新点)
attn = cal_rel_pos_spatial(
attn,
q,
self.has_cls_embed,
q_shape,
k_shape,
self.rel_pos_h,
self.rel_pos_w,
)
attn = attn.softmax(dim=-1)
x = attn @ v
# 进行残差连接(本篇论文的创新点)
if self.residual_pooling:
if self.has_cls_embed:
x[:, :, 1:, :] += q[:, :, 1:, :]
else:
x = x + q
x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
x = self.proj(x)
return x, q_shape # # x:[10,3137,96] , q_shape:[56,56]
虽然MViT在捕捉token之间关系方面已经展示了优异的性能,但是这种注意力更关注内容而不是空间结构。通过绝对位置编码提供位置信息忽略了图像中很重要的一点,就是特征的平移不变性。在原始MViT中如果两个patch的绝对位置发生变化那么他们之间的关系就会发生变化,即使这两个patch的相对位置并没有发生变化。
为了解决这个问题,论文引入了相对位置编码,即计算两个patch之间的相对位置信息,然后进行位置嵌入。同时为了减少内存和时间开销,论文将两个patch之间的距离沿着时空轴进行分解,分别沿着长,宽,时间来进行计算。这样计算的时间复杂度为O(T+W+H)。
# 添加相对位置编码
def cal_rel_pos_spatial(
attn, # attn : [10,1,3137,197]
q, # q : [10,1,3137,96]
has_cls_embed, # True
q_shape, #q_shape [56,56]
k_shape, #k_shape [14,14]
rel_pos_h, #rel_pos_h Parameter:(111,96) 可学习参数
rel_pos_w, #rel_pos_w Parameter:(111,96) 可学习参数
):
"""
Spatial Relative Positional Embeddings.
"""
sp_idx = 1 if has_cls_embed else 0
q_h, q_w = q_shape
k_h, k_w = k_shape
# Scale up rel pos if shapes for q and k are different.
q_h_ratio = max(k_h / q_h, 1.0)
k_h_ratio = max(q_h / k_h, 1.0)
dist_h = (
torch.arange(q_h)[:, None] * q_h_ratio - torch.arange(k_h)[None, :] * k_h_ratio
)
dist_h += (k_h - 1) * k_h_ratio
q_w_ratio = max(k_w / q_w, 1.0)
k_w_ratio = max(q_w / k_w, 1.0)
dist_w = (
torch.arange(q_w)[:, None] * q_w_ratio - torch.arange(k_w)[None, :] * k_w_ratio
)
dist_w += (k_w - 1) * k_w_ratio
# dist_w [56,14]
# dist_h [56,14]
# 表示相对位置索引
Rh = rel_pos_h[dist_h.long()]
Rw = rel_pos_w[dist_w.long()]
B, n_head, q_N, dim = q.shape
r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, Rh)
rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, Rw)
# rel_h : [10,1,56,56,14]
# rel_w : [10,1,56,56,14]
attn[:, :, sp_idx:, sp_idx:] = (
attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
+ rel_h[:, :, :, :, :, None]
+ rel_w[:, :, :, :, None, :]
).view(B, -1, q_h * q_w, k_h * k_w)
# 分别在W,H维度进行位置编码
return attn
# 代码展示了上图完整的数据流通过程
def forward(self, x, hw_shape): # x_block : [10,3137,96] hw_shape [56,56]
x_norm = self.norm1(x)
# self.attn 表示对x进行Linear构造Q,K,V。然后通过池化,MatMul(K),添加相对位置编码,Softmax,MatMul(V),残差连接(Q),Linear
# x_block : [10,3137,96] hw_shape [56,56] (56*56+1 = 3137)
x_block, hw_shape_new = self.attn(x_norm, hw_shape)
#数据经过某些block之后dim_out和dim_in 可能不相等
if self.dim_mul_in_att and self.dim != self.dim_out:
x = self.proj(x_norm)
# 对初始输入数据进行pooling操作,保证可以和x_block相加
# 这里的 kenenl、stride、padding size应该和对Q矩阵进行池化操作的kenenl、stride、padding相同
x_res, _ = attention_pool(
x, self.pool_skip, hw_shape, has_cls_embed=self.has_cls_embed
)
x = x_res + self.drop_path(x_block)
x_norm = self.norm2(x)
x_mlp = self.mlp(x_norm)
if not self.dim_mul_in_att and self.dim != self.dim_out:
x = self.proj(x_norm)
x = x + self.drop_path(x_mlp)
return x, hw_shape_new
总体流程代码
def forward(self, x):
x, bchw = self.patch_embed(x)
H, W = bchw[-2], bchw[-1]
B, N, C = x.shape
if self.cls_embed_on:
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if self.use_abs_pos:
x = x + self.pos_embed
thw = [H, W]
# 进入block块中进行处理
for blk in self.blocks:
x, thw = blk(x, thw)
x = self.norm(x)
# cls分类器进行分类
if self.cls_embed_on:
x = x[:, 0]
else:
x = x.mean(1)
x = self.head(x)
return x