PVT's Spatial Reduction Attention (SRA)

SRA simply uses a convolution to shrink the spatial size of K and V.

It can be understood as aggregating each R×R block of tokens into a single token; attention is then computed between the full-resolution Q and the K and V of the aggregated tokens.
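To see why this matters, here is a quick back-of-the-envelope check (a minimal sketch added for illustration; the feature-map size matches the example at the end of this post): the attention matrix shrinks from N × N to N × (N / R²).

# Illustrative only: compare the attention-matrix size with and without
# spatial reduction for a detection-sized feature map.
H, W, R = 224, 128, 2
N = H * W                    # number of query tokens
N_kv = (H // R) * (W // R)   # K/V tokens left after the R x R strided conv

print(N * N)     # 822083584 entries for vanilla self-attention
print(N * N_kv)  # 205520896 entries with SRA, i.e. R^2 = 4x fewer

The full PyTorch module is below.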

import torch
from torch import nn
 
class SpatialReductionAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
 
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
 
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(proj_drop)
 
        self.sr_ratio = sr_ratio
        # in the implementation, the spatial reduction is just a strided convolution
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)
 
    def forward(self, x, H, W):
        B, N, D = x.shape  # N = H * W
        q = self.q(x).reshape(B, N, self.num_heads, D // self.num_heads).permute(0, 2, 1, 3)
 
        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, D, H, W)
            x_ = self.sr(x_).reshape(B, D, -1).permute(0, 2, 1)  # x_.shape = (B, N/R^2, D)
            x_ = self.norm(x_)
            # detection/segmentation inputs have high resolution, so N is large;
            # reducing K/V to N/R^2 tokens keeps the N x (N/R^2) attention matrix affordable
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, D // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, D // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
 
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
 
        x = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = self.proj(x)
        x = self.dropout(x)
 
        return x

x = torch.rand(4, 224*128, 256)
attn = SpatialReductionAttention(dim=256, sr_ratio=2)
output = attn(x, H=224, W=128)
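If you run the example above, the output keeps the input shape while K and V are handled at reduced resolution internally (a quick check added for illustration, using the shapes from this example):

print(output.shape)  # torch.Size([4, 28672, 256]) -- same shape as the input
# Inside the module, sr_ratio=2 maps the (224, 128) token map to (112, 64),
# so K and V only contain 224 * 128 / 2**2 = 7168 tokens.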
