In 2024, a new plug-and-play attention mechanism, ELA (Efficient Local Attention), was released.
Structure overview:
(a) the SE module
(b) the CA module
(c) the ELA module
The paper summarizes the shortcomings of the CA module and proposes the ELA module to address them.
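Both CA and ELA start from the same directional ("strip") pooling, which compresses the feature map along one spatial axis at a time; the modules differ in how each pooled strip is then processed. A minimal sketch of that pooling step (tensor shapes are illustrative, not from the paper):

```python
import torch

# Strip pooling: average over one spatial axis at a time.
x = torch.randn(2, 64, 32, 32)   # [batch, channels, H, W]
x_h = x.mean(dim=3)              # pool over W -> [2, 64, 32], one descriptor per row
x_w = x.mean(dim=2)              # pool over H -> [2, 64, 32], one descriptor per column
print(x_h.shape, x_w.shape)
```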
To tune ELA's performance while keeping the parameter count in check (the relevant hyperparameters are the Conv1d kernel_size and groups, and the GroupNorm num_groups), the paper introduces four variants:
ELA-Tiny (ELA-T), ELA-Base (ELA-B), ELA-Small (ELA-S), ELA-Large (ELA-L)
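As a rough sanity check on how these hyperparameters affect model size: a Conv1d with bias=False has in_channels * (in_channels // groups) * kernel_size weights, and GroupNorm always carries 2 * in_channels affine parameters regardless of num_groups. A back-of-the-envelope sketch at in_channels = 256 (the numbers are illustrative, not taken from the paper):

```python
# Parameter counts per variant at in_channels = 256.
# Conv1d (bias=False): in_channels * (in_channels // groups) * kernel_size weights.
# GroupNorm: 2 * in_channels affine parameters, independent of num_groups.
in_ch = 256
variants = {'T': (5, in_ch), 'B': (7, in_ch),
            'S': (5, in_ch // 8), 'L': (7, in_ch // 8)}  # (kernel_size, groups)
for phi, (k, groups) in variants.items():
    conv = in_ch * (in_ch // groups) * k
    gn = 2 * in_ch
    print(f"ELA-{phi}: conv1d={conv}, groupnorm={gn}, total={conv + gn}")
```

Note how the depthwise variants (T and B, where groups equals in_channels) stay far smaller than S and L, which use groups = in_channels // 8.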
Implementation:
import torch
import torch.nn as nn


class ELA(nn.Module):
    """Efficient Local Attention.

    ELA-T and ELA-B are lightweight designs, well suited to shallow or
    lightweight CNN architectures; ELA-B and ELA-S perform best on networks
    with deeper structures; ELA-L is intended for large networks.
    """

    def __init__(self, in_channels, phi):
        super(ELA, self).__init__()
        kernel_size = {'T': 5, 'B': 7, 'S': 5, 'L': 7}[phi]
        groups = {'T': in_channels, 'B': in_channels,
                  'S': in_channels // 8, 'L': in_channels // 8}[phi]
        num_groups = {'T': 32, 'B': 16, 'S': 16, 'L': 16}[phi]
        pad = kernel_size // 2
        # A single 1D convolution, shared between the H and W directions.
        self.con1 = nn.Conv1d(in_channels, in_channels, kernel_size=kernel_size,
                              padding=pad, groups=groups, bias=False)
        self.GN = nn.GroupNorm(num_groups, in_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        b, c, h, w = input.size()
        # Strip pooling: average over W for a per-row descriptor,
        # and over H for a per-column descriptor.
        x_h = torch.mean(input, dim=3, keepdim=True).view(b, c, h)
        x_w = torch.mean(input, dim=2, keepdim=True).view(b, c, w)
        x_h = self.con1(x_h)  # [b, c, h]
        x_w = self.con1(x_w)  # [b, c, w]
        x_h = self.sigmoid(self.GN(x_h)).view(b, c, h, 1)  # [b, c, h, 1]
        x_w = self.sigmoid(self.GN(x_w)).view(b, c, 1, w)  # [b, c, 1, w]
        return x_h * x_w * input
if __name__ == "__main__":
    # Dummy input of shape [batch_size, channels, height, width]
    input = torch.randn(2, 256, 40, 40)
    ela = ELA(in_channels=256, phi='T')
    output = ela(input)
    print(output.size())  # torch.Size([2, 256, 40, 40])
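One practical caveat: nn.GroupNorm requires in_channels to be divisible by num_groups, and nn.Conv1d requires it to be divisible by groups, so not every channel count works with every variant. A small hypothetical helper (valid_for_variant is not from the paper) that checks this before constructing the module:

```python
# Check whether a given channel count satisfies the divisibility constraints
# of nn.GroupNorm (num_groups) and nn.Conv1d (groups) for each ELA variant.
# valid_for_variant is an illustrative helper, not part of the original code.
def valid_for_variant(in_channels: int, phi: str) -> bool:
    num_groups = {'T': 32, 'B': 16, 'S': 16, 'L': 16}[phi]
    groups = {'T': in_channels, 'B': in_channels,
              'S': in_channels // 8, 'L': in_channels // 8}[phi]
    return (groups > 0
            and in_channels % num_groups == 0
            and in_channels % groups == 0)

print([c for c in (16, 32, 64, 96, 256) if valid_for_variant(c, 'T')])  # → [32, 64, 96, 256]
```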