EffectiveSE注意力
EffectiveSE注意力机制全称是Effective Squeeze and Extraction,即插即用的注意力模块,基于SE(Squeeze and Extraction)改进而来。与SE的区别在于,EffectiveSE注意力机制只有一个全连接层,《CenterMask : Real-Time Anchor-Free Instance Segmentation》的作者注意到SE模块有一个缺点:由于维度的减少导致的通道信息损失。为了避免这种大模型的计算负担,se的2个全连接层需要减少通道维度。特别的,当第一个全连接层使用r减少输入特征通道,将通道数从c变为c/r的时候,第二个全连接层又需要扩张减少的通道数到原始的通道c。在这个过程中,通道维度的减少导致了通道信息的损失。因而,EffectiveSE注意力机制仅仅使用一个通道数为c的全连接层代替了两个全连接层,避免了通道信息的丢失。
论文地址:https://arxiv.org/pdf/1911.06667.pdf
代码如下:
import torch
from torch import nn as nn
from timm.models.layers.create_act import create_act_layer
class EffectiveSEModule(nn.Module):
def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'):
super(EffectiveSEModule, self).__init__()
self.add_maxpool = add_maxpool
self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
if self.add_maxpool:
# experimental codepath, may remove or change
x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
x_se = self.fc(x_se)
return x * self.gate(x_se)
if __name__ == '__main__':
input=torch.randn(50,512,7,7)
Ese = EffectiveSEModule(512)
output=Ese(input)
print(output.shape)