PVT's spatial reduction attention (SRA) simply uses a convolution to reduce the spatial size of K and V.
You can think of it as aggregating each R×R patch of tokens (R² points, where R is the reduction ratio sr_ratio) into a single token; attention is then computed between the full-resolution Q and the K and V of the aggregated tokens.
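To make the "aggregate an R×R patch into one token" view concrete, here is a minimal, self-contained sketch of just the reduction step (the sizes and variable names below are illustrative choices, not part of PVT):

import torch
from torch import nn

B, D, H, W, R = 2, 64, 32, 32, 4                    # toy sizes, chosen arbitrarily
tokens = torch.rand(B, H * W, D)                    # (B, N, D) with N = H * W = 1024

sr = nn.Conv2d(D, D, kernel_size=R, stride=R)       # one R x R patch -> one output token
x = tokens.permute(0, 2, 1).reshape(B, D, H, W)     # fold the token sequence back into a 2-D map
reduced = sr(x).reshape(B, D, -1).permute(0, 2, 1)  # (B, N / R^2, D)
print(reduced.shape)                                # torch.Size([2, 64, 64])

Each output token is a learned weighted combination of the R² = 16 input tokens in its patch, and this reduced sequence is what K and V are computed from in SRA.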
import torch
from torch import nn


class SpatialReductionAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        # in the implementation, the spatial reduction is just a strided convolution
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, D = x.shape  # N = H * W
        q = self.q(x).reshape(B, N, self.num_heads, D // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, D, H, W)
            x_ = self.sr(x_).reshape(B, D, -1).permute(0, 2, 1)  # x_.shape = (B, N / R^2, D)
            x_ = self.norm(x_)
            # for detection / segmentation the input resolution, and hence N, is large;
            # computing K and V from the reduced sequence keeps the attention cost manageable
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, D // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, D // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = self.proj(x)
        x = self.dropout(x)
        return x


x = torch.rand(4, 224 * 128, 256)
attn = SpatialReductionAttention(dim=256, sr_ratio=2)
output = attn(x, H=224, W=128)
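To get a rough sense of what the reduction saves, the snippet below (a hypothetical check, reusing the module defined above) compares the attention-matrix size with and without spatial reduction; the output sequence length stays at N either way:

H, W, R = 224, 128, 2
N = H * W                                        # 28672 query tokens
print("attn matrix, no SR :", (N, N))            # (28672, 28672)
print("attn matrix, R = 2 :", (N, N // R ** 2))  # (28672, 7168)

out = SpatialReductionAttention(dim=256, sr_ratio=R)(torch.rand(1, N, 256), H=H, W=W)
print(out.shape)                                 # torch.Size([1, 28672, 256])

So SRA cuts the per-layer attention cost by roughly a factor of R², which is what makes it affordable at the high resolutions used for detection and segmentation.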