Shunted Self-Attention via Multi-Scale Token Aggregation

【文章链接】[2111.15193] Shunted Self-Attention via Multi-Scale Token Aggregation (arxiv.org)

现有的Transformer模型采用各种降采样策略来减少特征尺寸和内存消耗。例如,VIT方法在第一层进行16 × 16的下采样投影,并在得到的粗粒度和单尺度特征图上计算自注意力;因此,所产生的特征信息损失不可避免地降低了模型的性能。PVT则致力于在细粒度特征上计算自注意力,并通过合并token的空间缩减来降低成本;然而合并过多的token会导致来自小目标和背景噪声的token混合。这样的行为反过来使得模型在捕获小物体时效果较差。

SSA的多尺度注意力机制是通过将多个注意力头拆分成若干组来实现的。每个组占一个专用的注意力粒度。对于细粒度的组,SSA学习聚合很少的token,并保留更多的局部细节。对于剩下的粗粒度头组,SSA学习聚合大量的token,从而降低计算成本,同时保留捕获大型对象的能力。多粒度组共同学习多粒度信息,使得模型能够对多尺度对象进行有效建模。

1.回忆PVT中的spatialreduction attention (SRA)

SRA是通过KV中的降低token长度来降低计算成本。

(1)token的维度是(HW,C),表示的是HW个token,特征维度是C。在计算Q,K,V的时候,需要token与三个不同的(C,C)大小的可训练参数做矩阵乘法。现在的计算量是3HWC^{^{2}}。得到Q,K,V三者的大小(HW,C)。

Q:(HW,C)*(C,C)----->(HW,C)

K:(HW,C)*(C,C)----->(HW,C)

V:(HW,C)*(C,C)----->(HW,C)

(2)Q与K的转置相乘,可以得到(HW,HW)大小的关系矩阵,计算量是H^{2}W^{2}C

(HW,C)*(C,HW)----->(HW,HW)

(3)关系矩阵与V相乘,计算量是H^{2}W^{2}C

(HW,HW)*(HW,C)----->(HW,C)

以上,如果K和V的大小不是(HW,C),而是(hw,C),那么以上(2)(3)的矩阵乘法就变成

(HW,C)*(C,hw)----->(HW,hw)       计算量HWhwC

(HW,hw)*(hw,C)----->(HW,C)   计算量HWhwC

由此可见,只要hw小于HW,计算量就会减少,以上就是SRA的概念。

PVT中是如何实现的呢?

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):

        .....  
        #计算QKV时,可训练矩阵大小与第一步相同,没有进行空间减少
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        .....
 
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            #使用卷积操作,设置步长,来降低特征图的大小
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)
 

2.Shunted Self-Attention(SSA)

如图所示,SSA与PVT中的SRA不同。对于SSA,在同一个自注意力层中,不同的头对应的K,V长度不同,这样可以捕获不同粒度的信息。在实践中,采用具有卷积核大小和步长均为r^{_{i}}的卷积层来实现下采样。公式为:

具体来说,对于由 i 索引的不同头,键 K 和值 V 被下采样到不同的大小。LE(·) 是通过深度卷积对 V 值进行局部增强的部分(这个操作的作用是什么哇)。

当r变大时,K、V中的token被合并,K、V的长度更短,因此计算成本较低,但仍然保留了捕获大对象的能力。相反,当 r 变小时,保留了更多细节,但带来了更多计算成本。将各种 r 集成到一个自注意力层中使其能够捕获多粒度特征。

代码中,只是被分为两种尺度,与文章中table1一致。

            if sr_ratio==8:
                self.sr1 = nn.Conv2d(dim, dim, kernel_size=8, stride=8)
                self.norm1 = nn.LayerNorm(dim)
                self.sr2 = nn.Conv2d(dim, dim, kernel_size=4, stride=4)
                self.norm2 = nn.LayerNorm(dim)
            if sr_ratio==4:
                self.sr1 = nn.Conv2d(dim, dim, kernel_size=4, stride=4)
                self.norm1 = nn.LayerNorm(dim)
                self.sr2 = nn.Conv2d(dim, dim, kernel_size=2, stride=2)
                self.norm2 = nn.LayerNorm(dim)
            if sr_ratio==2:
                self.sr1 = nn.Conv2d(dim, dim, kernel_size=2, stride=2)
                self.norm1 = nn.LayerNorm(dim)
                self.sr2 = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
                self.norm2 = nn.LayerNorm(dim)

3.Detail-specific Feedforward Layers

在传统的前馈层中,全连接层是逐点的,无法学习交叉令牌信息。在这里,我们的目标是通过指定前馈层中的细节来补充局部信息。如图 6 所示,我们通过在前馈层的两个全连接层之间添加数据特定层来补充前馈层中的局部细节:

(注意:1、图片与公式相比,激活层的位置对不上。2、PVT指的是PVT_v2)

前馈层公式:(图片与公式相比,激活层的位置对不上)

代码与公式一致:

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.act(x + self.dwconv(x, H, W))
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

### Tokens-to-Token机制概述 Tokens-to-Token(T2T)是一种用于提升Transformer模型效率和性能的技术。通过将多个输入标记组合成更少数量的高级表示,可以有效减少计算量并提高处理速度[^1]。 #### 实现方式 在具体实现上,通常采用多尺度令牌聚合方法来完成这一过程: - **分层降维**:首先定义一系列不同大小的感受野窗口,在每个位置提取局部特征形成初始token集合; - **自注意力模块**:接着利用shunted self-attention结构对这些tokens执行交互操作,使得相邻区域间的信息能够相互传递; - **迭代压缩**:最后经过多次重复上述两步流程逐步降低维度直至达到预期目标尺寸为止。 以下是Python伪代码展示如何构建一个多级别的Token聚合网络: ```python class MultiScaleTokenAggregator(nn.Module): def __init__(self, input_dim=768, output_dim=384, num_heads=8): super().__init__() # 定义不同的感受野大小 scales = [(i * 2 + 1) for i in range(3)] self.attentions = nn.ModuleList([ ShuntedSelfAttention(input_dim=input_dim, head_num=num_heads, kernel_size=k) for k in scales]) self.fc = nn.Linear(len(scales)*input_dim, output_dim) def forward(self, x): out = [] for atten_layer in self.attentions: z = atten_layer(x).flatten(-2,-1) out.append(z.unsqueeze(dim=-1)) return F.relu(self.fc(torch.cat(out,dim=-1))) ``` 此架构允许灵活调整参数以适应各种视觉识别任务需求,并已在ImageNet数据集上取得了良好的分类效果[^3]。 #### 应用场景 该技术特别适用于图像理解和自然语言处理领域中的大规模预训练模型优化工作。例如,在计算机视觉方面可以通过这种方式显著加快ResNet、EfficientNet等经典CNN框架下的微调收敛速率;而在NLP方向则有助于BERT、RoBERTa之类基于Transformer的语言模型更好地捕捉长距离依赖关系。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值