1.ECA介绍
摘要:最近,通道注意力机制已被证明在提高深度卷积神经网络(CNN)性能方面具有巨大潜力。 然而,大多数现有方法致力于开发更复杂的注意力模块以实现更好的性能,这不可避免地增加了模型的复杂性。 为了克服性能和复杂性权衡的悖论,本文提出了一种高效通道注意(ECA)模块,该模块仅涉及少量参数,同时带来了明显的性能增益。 通过剖析 SENet 中的通道注意力模块,我们凭经验证明避免降维对于学习通道注意力非常重要,适当的跨通道交互可以保持性能,同时显着降低模型复杂性。 因此,我们提出了一种无需降维的局部跨通道交互策略,可以通过一维卷积有效地实现。 此外,我们开发了一种自适应选择一维卷积核大小的方法,确定局部跨通道交互的覆盖范围。 所提出的 ECA 模块高效且有效,例如,我们的模块针对 ResNet50 主干网的参数和计算量分别为 80 vs. 24.37M 和 4.7e-4 GFLOPs vs. 3.86 GFLOPs,性能提升超过 2% 就Top-1准确率而言。 我们以 ResNets 和 MobileNetV2 为骨干,在图像分类、对象检测和实例分割方面广泛评估了我们的 ECA 模块。 实验结果表明,我们的模块效率更高,同时性能优于同类模块。
官方论文地址:https://arxiv.org/pdf/1910.03151
官方代码地址:https://github.com/search?q=ECA&type=repositories
简单介绍: ECA(Efficient Channel Attention)注意力机制的原理可精炼概括为:通过避免在通道注意力模块中引入降维操作,采用了一种创新的局部跨通道交互策略。这种方法巧妙地利用1D卷积来实现高效的通道注意力计算,从而在维持出色性能的同时,大幅减少了模型的复杂性。ECA模块通过自适应选择卷积核大小,精确界定了局部跨通道交互的覆盖范围,实现了参数的高效利用和计算成本的显著降低。在ResNets和MobileNetV2等主流神经网络结构中,ECA模块以极少的参数和计算资源,显著提升了模型的性能,相比其他注意力模块展现出更高的效率和优越性。
ECA模块结构图:
2.核心代码
import torch
from torch import nn
from torch.nn.parameter import Parameter
class ECA(nn.Module):
def __init__(self, channel, k_size=3):
super(ECA, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# feature descriptor on the global spatial information
y = self.avg_pool(x)
# Two different branches of ECA module
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
# Multi-scale information fusion
y = self.sigmoid(y)
return x * y.expand_as(x)
3.YOLOv11中添加ECA方式
3.1 在ultralytics/nn下新建Extramodule
3.2 在Extramodule里创建ECA
在ECA.py文件里添加给出的ECA代码
添加完ECA代码后,在ultralytics/nn/Extramodule/__init__.py文件中引用
3.3 在tasks.py里引用
在ultralytics/nn/tasks.py文件里引用Extramodule
在tasks.py找到parse_model(ctrl+f可以直接搜索parse_model位置)
添加如下代码:
elif m in {ECA}:
c2 = ch[f]
args = [c2, *args]
4.新建一个yolo11ECA.yaml文件
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs
# YOLO11n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 2, C3k2, [256, False, 0.25]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 2, C3k2, [512, False, 0.25]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 2, C3k2, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 2, C3k2, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 2, C2PSA, [1024]] # 10
# YOLO11n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 2, C3k2, [512, False]] # 13
- [-1, 1, ECA, []]
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
- [-1, 1, ECA, []]
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
- [-1, 1, ECA, []]
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
- [-1, 1, ECA, []]
- [[17, 21, 26], 1, Detect, [nc]] # Detect(P3, P4, P5)
大家根据自己的数据集实际情况,修改nc大小。
5.模型训练
import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO
if __name__ == '__main__':
model = YOLO(r'D:\yolo\yolov11\ultralytics-main\datasets\yolo11ECA.yaml')
model.train(data=r'D:\yolo\yolov11\ultralytics-main\datasets\data.yaml',
cache=False,
imgsz=640,
epochs=100,
single_cls=False, # 是否是单类别检测
batch=8,
close_mosaic=10,
workers=0,
device='0',
optimizer='SGD',
amp=True,
project='runs/train',
name='exp',
)
模型结构打印,成功运行:
6.本文总结
到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的,后期我会根据各种前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~