Improved Multiscale Vision Transformers

Improved Multiscale Vision Transformers

MViT v2引入了两个设计来进一步提升MViT v1的性能
(1)、使用可分解的相对位置编码来注入位置信息
(2)、使用残差连接来弥补SA计算中步长的损失

MViT v1中在不同的stage建立不同分辨率的模块。数据的通道数逐渐增多,同时数据的分辨率(对应的序列长度)会逐渐下降。为了在Transformer模块中实现下采样操作,MViT提出了Pooling Attention。

对于任意输入序列,通过Linear得到Q,K,V矩阵。Q,K,V矩阵经过池化处理,在池化之后的基础上进行注意力的计算。其中对K、V矩阵进行池化操作的kernal,stride,padding保持一致,对Q矩阵,残差处理的池化操作的kernal,stride,padding保持一致。
在这里插入图片描述
Pooling Attention可以在每个stage都进行池化,这样可以大大减少Q-K-V计算时的内存成本和计算量。

MViTv1中引入的MSPA池化注意力大大减少SA的计算量,主要会在Q-K-V进行线性映射后在进行一步池化操作,但是在v1中K,V采用的步长更大,Q只有在输出序列发生变化时才进行降采样,这就需要在pooling attention module的计算中加入残差连接来增加信息流动。MViT v2在注意力模块中引入一种新的残差池化连接,表示为以下公式:

在这里插入图片描述
模型会在注意力计算后与pooled Q进行残差连接作为最终的输出。(Q和Z的shape应该相同)
消融实验表明,使用残差连接和对Q进行池化都是很有比较的,一方面可以降低SA的计算复杂福一方面可以提升性能

	# 输入tensor查看数据处理过程 (batch_size,channel,H,W)
 	model = build_model(cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.rand((10, 3, 224, 224))
    x = x.to(device)
    model(x)

在这里插入图片描述

#Pooling Attention 
#此代码为上图部分的数据处理流程
    def forward(self, x, hw_shape): # x:[10,3137,96] , hw_shape:[56,56]
    
        B, N, _ = x.shape

        if self.pool_first:
            if self.mode == "conv_unshared":
                fold_dim = 1
            else:
                fold_dim = self.num_heads
            x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
            q = k = v = x
        else:
            assert self.mode != "conv_unshared"

            qkv = (
                self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            )
            q, k, v = qkv[0], qkv[1], qkv[2] # q.shape = k.shape = v.shape = [10,1,3137,96] 1为num_heads 
			
		# attention_pool为执行pooling操作
		#  self.pool_k: Conv2d(96, 96, kernel_size=(3, 3), stride=(4, 4), padding=(1, 1), groups=96, bias=False)
		#  self.pool_q: Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96, bias=False)
		#  self.pool_v: Conv2d(96, 96, kernel_size=(3, 3), stride=(4, 4), padding=(1, 1), groups=96, bias=False)
		#  代码中第一层block没有 增加/降低 图片的 channel/分辨率 。
        q, q_shape = attention_pool( # q: [10,1,3137,96] q_shape: [56,56]
            q,
            self.pool_q,
            hw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_q if hasattr(self, "norm_q") else None,
        )
        k, k_shape = attention_pool( # k: [10,1,197,96] k_shape: [14,14]
            k,
            self.pool_k,
            hw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_k if hasattr(self, "norm_k") else None,
        )
        v, v_shape = attention_pool( # v: [10,1,197,96] v_shape: [14,14]
            v,
            self.pool_v,
            hw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_v if hasattr(self, "norm_v") else None,
        )

        if self.pool_first:
            q_N = numpy.prod(q_shape) + 1 if self.has_cls_embed else numpy.prod(q_shape)
            k_N = numpy.prod(k_shape) + 1 if self.has_cls_embed else numpy.prod(k_shape)
            v_N = numpy.prod(v_shape) + 1 if self.has_cls_embed else numpy.prod(v_shape)

            q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
            q = self.q(q).reshape(B, q_N, self.num_heads, -1).permute(0, 2, 1, 3)

            v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
            v = self.v(v).reshape(B, v_N, self.num_heads, -1).permute(0, 2, 1, 3)

            k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
            k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 1, 3)

        N = q.shape[2]
        attn = (q * self.scale) @ k.transpose(-2, -1) # attn: [10,1,3137,197]
        if self.rel_pos_spatial:
        	#添加相对位置编码(本篇论文的创新点)
            attn = cal_rel_pos_spatial(
                attn,
                q,
                self.has_cls_embed,
                q_shape,
                k_shape,
                self.rel_pos_h,
                self.rel_pos_w,
            )

        attn = attn.softmax(dim=-1)
        x = attn @ v
		# 进行残差连接(本篇论文的创新点)
        if self.residual_pooling:
            if self.has_cls_embed:
                x[:, :, 1:, :] += q[:, :, 1:, :]
            else:
                x = x + q

        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)

        return x, q_shape # # x:[10,3137,96] , q_shape:[56,56]

虽然MViT在捕捉token之间关系方面已经展示了优异的性能,但是这种注意力更关注内容而不是空间结构。通过绝对位置编码提供位置信息忽略了图像中很重要的一点,就是特征的平移不变性。在原始MViT中如果两个patch的绝对位置发生变化那么他们之间的关系就会发生变化,即使这两个patch的相对位置并没有发生变化。
为了解决这个问题,论文引入了相对位置编码,即计算两个patch之间的相对位置信息,然后进行位置嵌入。同时为了减少内存和时间开销,论文将两个patch之间的距离沿着时空轴进行分解,分别沿着长,宽,时间来进行计算。这样计算的时间复杂度为O(T+W+H)。

# 添加相对位置编码
def cal_rel_pos_spatial(
    attn, # attn : [10,1,3137,197]
    q, # q : [10,1,3137,96]
    has_cls_embed, # True
    q_shape, #q_shape [56,56]
    k_shape, #k_shape [14,14]
    rel_pos_h, #rel_pos_h Parameter:(111,96) 可学习参数
    rel_pos_w, #rel_pos_w Parameter:(111,96) 可学习参数
):
    """
    Spatial Relative Positional Embeddings.
    """
    sp_idx = 1 if has_cls_embed else 0
    q_h, q_w = q_shape
    k_h, k_w = k_shape

    # Scale up rel pos if shapes for q and k are different.
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    dist_h = (
        torch.arange(q_h)[:, None] * q_h_ratio - torch.arange(k_h)[None, :] * k_h_ratio
    )
    dist_h += (k_h - 1) * k_h_ratio
    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    dist_w = (
        torch.arange(q_w)[:, None] * q_w_ratio - torch.arange(k_w)[None, :] * k_w_ratio
    )
    dist_w += (k_w - 1) * k_w_ratio
	# dist_w [56,14]
	# dist_h [56,14]
	# 表示相对位置索引
    Rh = rel_pos_h[dist_h.long()]
    Rw = rel_pos_w[dist_w.long()]

    B, n_head, q_N, dim = q.shape

    r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
    rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, Rh)
    rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, Rw)
	# rel_h : [10,1,56,56,14]
	# rel_w : [10,1,56,56,14]
    attn[:, :, sp_idx:, sp_idx:] = (
        attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
        + rel_h[:, :, :, :, :, None]
        + rel_w[:, :, :, :, None, :]
    ).view(B, -1, q_h * q_w, k_h * k_w)
# 分别在W,H维度进行位置编码
    return attn

在这里插入图片描述

	# 代码展示了上图完整的数据流通过程
    def forward(self, x, hw_shape): # x_block : [10,3137,96] hw_shape [56,56]
        x_norm = self.norm1(x)
        # self.attn 表示对x进行Linear构造Q,K,V。然后通过池化,MatMul(K),添加相对位置编码,Softmax,MatMul(V),残差连接(Q),Linear
        # x_block : [10,3137,96] hw_shape [56,56] (56*56+1 = 3137)
        x_block, hw_shape_new = self.attn(x_norm, hw_shape)
		
		#数据经过某些block之后dim_out和dim_in 可能不相等
        if self.dim_mul_in_att and self.dim != self.dim_out:
            x = self.proj(x_norm)
            
        # 对初始输入数据进行pooling操作,保证可以和x_block相加 
        # 这里的 kenenl、stride、padding size应该和对Q矩阵进行池化操作的kenenl、stride、padding相同
        x_res, _ = attention_pool(
            x, self.pool_skip, hw_shape, has_cls_embed=self.has_cls_embed
        )
        x = x_res + self.drop_path(x_block)
        x_norm = self.norm2(x)
        x_mlp = self.mlp(x_norm)

        if not self.dim_mul_in_att and self.dim != self.dim_out:
            x = self.proj(x_norm)
        x = x + self.drop_path(x_mlp)
        return x, hw_shape_new

总体流程代码

    def forward(self, x):
        x, bchw = self.patch_embed(x)
        H, W = bchw[-2], bchw[-1]
        B, N, C = x.shape

        if self.cls_embed_on:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        if self.use_abs_pos:
            x = x + self.pos_embed

        thw = [H, W]
        # 进入block块中进行处理
        for blk in self.blocks:
            x, thw = blk(x, thw)

        x = self.norm(x)
		# cls分类器进行分类
        if self.cls_embed_on:
            x = x[:, 0]
        else:
            x = x.mean(1)

        x = self.head(x)
        return x
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值