Parallel Polarized Self Attention
细粒度的像素级任务(比如语义分割)一直都是计算机视觉中非常重要的任务。不同于分类或者检测,细粒度的像素级任务要求模型在低计算开销下,能够建模高分辨率输入/输出特征的远程依赖关系,进而来估计高度非线性的像素语义。CNN中的注意力机制能够捕获长距离的依赖关系,但是这种方式比较复杂并且是对噪声比较敏感的。
对于这类任务,通常采用的是encoder-decoder的结构,encoder用来降低空间维度、提高通道维度;decoder通常是转置卷积或者上采样,用来提高空间的维度、降低通道的维度。因此连接encoder和decoder的tensor通常在空间维度上比较小,虽然这对于计算和显存的使用比较友好,但是对于像实例分割这样的细粒度像素级任务,这种结构显然会造成性能上的损失。
基于此,作者提出了一个即插即用的模块——极化自注意力机制( Polarized Self-Attention(PSA)),用于解决像素级的回归任务,相比于其他注意力机制,极化自注意力机制主要有两个设计上的亮点:
1)极化滤波( Polarized filtering):在通道和空间维度保持比较高的resolution(在通道上保持C/2的维度,在空间上保持[H,W]的维度 ),这一步能够减少降维度造成的信息损失;
2)增强(Enhancement):组合非线性直接拟合典型细粒度回归的输出分布。
论文地址:https://arxiv.org/abs/2107.00782
代码如下:
import numpy as np
import torch
from torch import nn
from torch.nn import init
class ParallelPolarizedSelfAttention(nn.Module):
def __init__(self, channel=512):
super().__init__()
self.ch_wv=nn.Conv2d(channel,channel//2,kernel_size=(1,1))
self.ch_wq=nn.Conv2d(channel,1,kernel_size=(1,1))
self.softmax_channel=nn.Softmax(1)
self.softmax_spatial=nn.Softmax(-1)
self.ch_wz=nn.Conv2d(channel//2,channel,kernel_size=(1,1))
self.ln=nn.LayerNorm(channel)
self.sigmoid=nn.Sigmoid()
self.sp_wv=nn.Conv2d(channel,channel//2,kernel_size=(1,1))
self.sp_wq=nn.Conv2d(channel,channel//2,kernel_size=(1,1))
self.agp=nn.AdaptiveAvgPool2d((1,1))
def forward(self, x):
b, c, h, w = x.size()
#Channel-only Self-Attention
channel_wv=self.ch_wv(x) #bs,c//2,h,w
channel_wq=self.ch_wq(x) #bs,1,h,w
channel_wv=channel_wv.reshape(b,c//2,-1) #bs,c//2,h*w
channel_wq=channel_wq.reshape(b,-1,1) #bs,h*w,1
channel_wq=self.softmax_channel(channel_wq)
channel_wz=torch.matmul(channel_wv,channel_wq).unsqueeze(-1) #bs,c//2,1,1
channel_weight=self.sigmoid(self.ln(self.ch_wz(channel_wz).reshape(b,c,1).permute(0,2,1))).permute(0,2,1).reshape(b,c,1,1) #bs,c,1,1
channel_out=channel_weight*x
#Spatial-only Self-Attention
spatial_wv=self.sp_wv(x) #bs,c//2,h,w
spatial_wq=self.sp_wq(x) #bs,c//2,h,w
spatial_wq=self.agp(spatial_wq) #bs,c//2,1,1
spatial_wv=spatial_wv.reshape(b,c//2,-1) #bs,c//2,h*w
spatial_wq=spatial_wq.permute(0,2,3,1).reshape(b,1,c//2) #bs,1,c//2
spatial_wq=self.softmax_spatial(spatial_wq)
spatial_wz=torch.matmul(spatial_wq,spatial_wv) #bs,1,h*w
spatial_weight=self.sigmoid(spatial_wz.reshape(b,1,h,w)) #bs,1,h,w
spatial_out=spatial_weight*x
out=spatial_out+channel_out
return out
if __name__ == '__main__':
input=torch.randn(1,512,7,7)
psa = ParallelPolarizedSelfAttention(channel=512)
output=psa(input)
print(output.shape)