Shunted Self-Attention via Multi-Scale Token Aggregation

【文章链接】[2111.15193] Shunted Self-Attention via Multi-Scale Token Aggregation (arxiv.org)

现有的Transformer模型采用各种降采样策略来减少特征尺寸和内存消耗。例如,VIT方法在第一层进行16 × 16的下采样投影,并在得到的粗粒度和单尺度特征图上计算自注意力;因此,所产生的特征信息损失不可避免地降低了模型的性能。PVT则致力于在细粒度特征上计算自注意力,并通过合并token的空间缩减来降低成本;然而合并过多的token会导致来自小目标和背景噪声的token混合。这样的行为反过来使得模型在捕获小物体时效果较差。

SSA的多尺度注意力机制是通过将多个注意力头拆分成若干组来实现的。每个组占一个专用的注意力粒度。对于细粒度的组,SSA学习聚合很少的token,并保留更多的局部细节。对于剩下的粗粒度头组,SSA学习聚合大量的token,从而降低计算成本,同时保留捕获大型对象的能力。多粒度组共同学习多粒度信息,使得模型能够对多尺度对象进行有效建模。

1.回忆PVT中的spatialreduction attention (SRA)

SRA是通过KV中的降低token长度来降低计算成本。

(1)token的维度是(HW,C),表示的是HW个token,特征维度是C。在计算Q,K,V的时候,需要token与三个不同的(C,C)大小的可训练参数做矩阵乘法。现在的计算量是3HWC^{^{2}}。得到Q,K,V三者的大小(HW,C)。

Q:(HW,C)*(C,C)----->(HW,C)

K:(HW,C)*(C,C)----->(HW,C)

V:(HW,C)*(C,C)----->(HW,C)

(2)Q与K的转置相乘,可以得到(HW,HW)大小的关系矩阵,计算量是H^{2}W^{2}C

(HW,C)*(C,HW)----->(HW,HW)

(3)关系矩阵与V相乘,计算量是H^{2}W^{2}C

(HW,HW)*(HW,C)----->(HW,C)

以上,如果K和V的大小不是(HW,C),而是(hw,C),那么以上(2)(3)的矩阵乘法就变成

(HW,C)*(C,hw)----->(HW,hw)       计算量HWhwC

(HW,hw)*(hw,C)----->(HW,C)   计算量HWhwC

由此可见,只要hw小于HW,计算量就会减少,以上就是SRA的概念。

PVT中是如何实现的呢?

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):

        .....  
        #计算QKV时,可训练矩阵大小与第一步相同,没有进行空间减少
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        .....
 
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            #使用卷积操作,设置步长,来降低特征图的大小
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)
 

2.Shunted Self-Attention(SSA)

如图所示,SSA与PVT中的SRA不同。对于SSA,在同一个自注意力层中,不同的头对应的K,V长度不同,这样可以捕获不同粒度的信息。在实践中,采用具有卷积核大小和步长均为r^{_{i}}的卷积层来实现下采样。公式为:

具体来说,对于由 i 索引的不同头,键 K 和值 V 被下采样到不同的大小。LE(·) 是通过深度卷积对 V 值进行局部增强的部分(这个操作的作用是什么哇)。

当r变大时,K、V中的token被合并,K、V的长度更短,因此计算成本较低,但仍然保留了捕获大对象的能力。相反,当 r 变小时,保留了更多细节,但带来了更多计算成本。将各种 r 集成到一个自注意力层中使其能够捕获多粒度特征。

代码中,只是被分为两种尺度,与文章中table1一致。

            if sr_ratio==8:
                self.sr1 = nn.Conv2d(dim, dim, kernel_size=8, stride=8)
                self.norm1 = nn.LayerNorm(dim)
                self.sr2 = nn.Conv2d(dim, dim, kernel_size=4, stride=4)
                self.norm2 = nn.LayerNorm(dim)
            if sr_ratio==4:
                self.sr1 = nn.Conv2d(dim, dim, kernel_size=4, stride=4)
                self.norm1 = nn.LayerNorm(dim)
                self.sr2 = nn.Conv2d(dim, dim, kernel_size=2, stride=2)
                self.norm2 = nn.LayerNorm(dim)
            if sr_ratio==2:
                self.sr1 = nn.Conv2d(dim, dim, kernel_size=2, stride=2)
                self.norm1 = nn.LayerNorm(dim)
                self.sr2 = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
                self.norm2 = nn.LayerNorm(dim)

3.Detail-specific Feedforward Layers

在传统的前馈层中,全连接层是逐点的,无法学习交叉令牌信息。在这里,我们的目标是通过指定前馈层中的细节来补充局部信息。如图 6 所示,我们通过在前馈层的两个全连接层之间添加数据特定层来补充前馈层中的局部细节:

(注意:1、图片与公式相比,激活层的位置对不上。2、PVT指的是PVT_v2)

前馈层公式:(图片与公式相比,激活层的位置对不上)

代码与公式一致:

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.act(x + self.dwconv(x, H, W))
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
校园失物招领系统管理系统按照操作主体分为管理员和用户。管理员的功能包括字典管理、论坛管理、公告信息管理、失物招领管理、失物认领管理、寻物启示管理、寻物认领管理、用户管理、管理员管理。用户的功能等。该系统采用了Mysql数据库,Java语言,Spring Boot框架等技术进行编程实现。 校园失物招领系统管理系统可以提高校园失物招领系统信息管理问题的解决效率,优化校园失物招领系统信息处理流程,保证校园失物招领系统信息数据的安全,它是一个非常可靠,非常安全的应用程序。 ,管理员权限操作的功能包括管理公告,管理校园失物招领系统信息,包括失物招领管理,培训管理,寻物启事管理,薪资管理等,可以管理公告。 失物招领管理界面,管理员在失物招领管理界面中可以对界面中显示,可以对失物招领信息的失物招领状态进行查看,可以添加新的失物招领信息等。寻物启事管理界面,管理员在寻物启事管理界面中查看寻物启事种类信息,寻物启事描述信息,新增寻物启事信息等。公告管理界面,管理员在公告管理界面中新增公告,可以删除公告。公告类型管理界面,管理员在公告类型管理界面查看公告的工作状态,可以对公告的数据进行导出,可以添加新公告的信息,可以编辑公告信息,删除公告信息。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值