阅读论文《Vision Transformer with Super Token Sampling》

一、背景介绍

这篇论文是CVPR2023的一篇论文,主要工作是对于Vision Transformer的自注意力机制进行了魔改。我感觉这篇文章或许对我的工作有帮助,因此,今天精读一下。(侵权删)(非常欢迎来argue,指正我的错误)

论文下载地址:[2211.11167] Vision Transformer with Super Token Sampling (arxiv.org)

代码开源仓库:hhb072/STViT (github.com)

在论文中,魔改前的机制似乎被叫做vanilla。

二、论文的introduction

Transformer被demonstrated在很多任务中表现出色,dominate了NLP,也被demonstrated在cv中表现好,但是,transformer的自注意力机制的computational complexity很高,是token数量的二次方,导致了巨大的计算开销。
有研究表明,ViTs(vision transformer)倾向于捕获具有高冗余度的浅层的局部特征。说人话就是在神经网络的前面几层的全局注意集中在一些相邻的标记(充满红色)上,而忽略了大部分远距离的标记。
可以看出,在shallow-layer中,全局表示做的并不好,所以论文就想改一下注意力机制,来使神经网络的早期阶段获得高效和有效的全局表示。
论文借鉴了superpixels的思想,提出了 super token attention(STA) mechanism。主要分成三步:
1. 首先,应用一种快速采样算法,通过学习标记和超级标记之间的稀疏关联来预测超级标记。
2. 然后,在超级标记空间中进行自注意,以捕获超级标记之间的长期依赖关系。
3. 我们在第一步中使用学习到的关联,将超级标记映射回原始的标记空间。
这个特殊的自注意力机制也被作者结合了自己设计的特殊的block: super token transformer (STT) blocks,由三个key modules组成:Convolutional Position Embedding (CPE), Super Token Attention (STA), and Convolutional Feed-Forward Network (ConvFFN).

三、论文的Related Works

主要是经典论文ViT和Superpixel Algorithms。

四、论文的模型细节

4.1. Overall Architecture

这篇论文的主要工作(STT)如下:

先从convolution stem("stem"通常指的是模型的前几层,这些层通常执行一些初步的特征提取)提取特征开始,因为有几篇论文也用过,发现和最初始的transform的tokenization (vanilla tokenization)相比,效果更好。然后提取的tokens经过四个阶段的STT block进行层次表示提取,以减少tokens数量,最后能输出分类。

这个super token transformer (STT) blocks,由三个key modules组成:Convolutional Position Embedding (CPE), Super Token Attention (STA), and Convolutional Feed-Forward Network (ConvFFN)。是一个线性相连的结构:

X = CPE(X_{in}) + X_{in}

Y = STA(LN(X)) + X

Z = ConvFFN(BN(Y)) + Y

4.2. Super Token Attention

Super Token Attention (STA)部分包含三个步骤:Super Token Sampling (STS), Multi-Head Self-Attention (MHSA), and Token Upsampling (TU)。这里主要讲STS和TU。MHSA我之前文章有讲过。下面是该作者开源的代码,我进行详细解读
class StokenAttention(nn.Module):
    def __init__(self, dim, stoken_size, n_iter=1, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        
        self.n_iter = n_iter
        self.stoken_size = stoken_size
                
        self.scale = dim ** - 0.5
        
        self.unfold = Unfold(3)
        self.fold = Fold(3)
        
        self.stoken_refine = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop)
       
        
    def stoken_forward(self, x):
        '''
           x: (B, C, H, W)
        '''
        B, C, H0, W0 = x.shape
        h, w = self.stoken_size
        
        pad_l = pad_t = 0
        pad_r = (w - W0 % w) % w
        pad_b = (h - H0 % h) % h
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
            
        _, _, H, W = x.shape
        
        hh, ww = H//h, W//w
        
        stoken_features = F.adaptive_avg_pool2d(x, (hh, ww)) # (B, C, hh, ww)
        
        pixel_features = x.reshape(B, C, hh, h, ww, w).permute(0, 2, 4, 3, 5, 1).reshape(B, hh*ww, h*w, C)
        
        with torch.no_grad():
            for idx in range(self.n_iter):
                stoken_features = self.unfold(stoken_features) # (B, C*9, hh*ww)
                stoken_features = stoken_features.transpose(1, 2).reshape(B, hh*ww, C, 9)
                affinity_matrix = pixel_features @ stoken_features * self.scale # (B, hh*ww, h*w, 9)
                
                affinity_matrix = affinity_matrix.softmax(-1) # (B, hh*ww, h*w, 9)
               
                affinity_matrix_sum = affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww)
               
                affinity_matrix_sum = self.fold(affinity_matrix_sum)
                if idx < self.n_iter - 1:
                    stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix # (B, hh*ww, C, 9)
                    
                    stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B*C, 9, hh, ww)).reshape(B, C, hh, ww)            
                    
                    stoken_features = stoken_features/(affinity_matrix_sum + 1e-12) # (B, C, hh, ww)
                    
        
        stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix # (B, hh*ww, C, 9)
       
        stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B*C, 9, hh, ww)).reshape(B, C, hh, ww)            
        
        stoken_features = stoken_features/(affinity_matrix_sum.detach() + 1e-12) # (B, C, hh, ww)
        
        stoken_features = self.stoken_refine(stoken_features)
        
        
        stoken_features = self.unfold(stoken_features) # (B, C*9, hh*ww)
        stoken_features = stoken_features.transpose(1, 2).reshape(B, hh*ww, C, 9) # (B, hh*ww, C, 9)
       
        pixel_features = stoken_features @ affinity_matrix.transpose(-1, -2) # (B, hh*ww, C, h*w)
       
        pixel_features = pixel_features.reshape(B, hh, ww, C, h, w).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)
                     
        if pad_r > 0 or pad_b > 0:
            pixel_features = pixel_features[:, :, :H0, :W0]
        
        return pixel_features
    
    
    def direct_forward(self, x):
        B, C, H, W = x.shape
        stoken_features = x
        stoken_features = self.stoken_refine(stoken_features)        
        return stoken_features
        
    def forward(self, x):
        if self.stoken_size[0] > 1 or self.stoken_size[1] > 1:
            return self.stoken_forward(x)
        else:
            return self.direct_forward(x)

4.2.1.Super Token Sampling (STS)

这一步是借鉴superpixels的一篇论文,通过稀疏关联学习从视觉标记中采样超标记。在他的代码中,这一步在 stoken_forward 函数中实现,首先,函数通过 F.padF.adaptive_avg_pool2d 将输入图像分解为一组超标记(stoken)。然后,它通过 unfold 函数将每个超标记进一步分解为一组像素特征。这些像素特征然后被用于计算每个超标记的特征。这就是为什么它被叫做super,因为他不像vit和swin是一整块的特征,是一组像素特征

4.2.2.Token Upsampling (TU)

将超标记映射回原始的标记空间。这一步在 stoken_forward 函数的最后部分实现。函数首先通过 fold 函数将处理过的超标记特征重新组合成原始图像的大小。然后,它通过 direct_forward 函数将这些特征映射回原始的像素空间。

4.2.3.复杂性分析

该工作相当于是将计算开销更大的 N × N 注意力矩阵 A(X) 分解为稀疏矩阵和小矩阵的乘法,即稀疏关联Q和m × m 注意力矩阵 A(S)。首先先给出对Global Self-Attention的复杂性的分析:

\Omega (GSA) = 2N^2C + 4NC^2

再对STA的复杂性进行了分析:

\Omega(STA) = \Omega(STS) + \Omega(MHSA) + \Omega (TU) = 19NC + (2m^2C + 4mC^2) + 9NC = 2m^2C + 4mC^2 + 28NC

很明显,这个m是要小于N的,这个复杂性是有所降低的!

5.实验

文章把这个结构应用在了不同的任务上,发现均有不错的提升,得出结论,STA yes!

  • 12
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值