一、Gam注意力源码
import torch.nn as nn
import torch
class GAM_Attention(nn.Module):
def __init__(self, in_channels,c2, rate=4):
super(GAM_Attention, self).__init__()
self.channel_attention = nn.Sequential(
nn.Linear(in_channels, int(in_channels / rate)),
nn.ReLU(inplace=True),
nn.Linear(int(in_channels / rate), in_channels)
)
self.spatial_attention = nn.Sequential(
nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
nn.BatchNorm2d(int(in_channels / rate)),
nn.ReLU(inplace=True),
nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
nn.BatchNorm2d(in_channels)
)
def forward(self, x):
b, c, h, w = x.shape
x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()
x = x * x_channel_att
x_spatial_att = self.spatial_attention(x).sigmoid()
out = x * x_spatial_att
return out
if __name__ == '__main__':
x = torch.randn(1, 64, 20, 20)
b, c, h, w = x.shape
net = GAM_Attention(in_channels=c)
y = net(x)
print(y.size())
二、添加方法
此方法仅适用于新版YOLOv8,旧版YOLOv8添加方法略有不同
1、添加注意力源码
在ultralytics/nn/modules/conv.py文件内添加注意力源码
2、注册并引用注意力
在ultralytics/nn/modules/__init__.py文件内,按下图标识的地方添加注意力名
第一处:在from .conv import()处最后,添加注意力名称
第二处:在__all__={}处最后,添加注意力名称
3、调用注意力
在ultralytics/nn/tasks.py文件内,键盘点击CTRL+shift+F打开查找界面,搜索
def parse_model(d, ch, verbose=True):
在该函数下方有一堆的elif m in XXX,在某一个elif下方添加如下代码:
elif m in {GAM_Attention}:
c1, c2 = ch[f], args[0]
if c2 != nc: # if not output
c2 = make_divisible(min(c2, max_channels) * width, 8)
args = [c1, c2, *args[1:]]
4、完成配置
在ultralytics/cfg/models/v8文件下,复制yolov8.yaml,并改成自己的名字,复制对应注意力的代码,这里我以Gam注意力为例(不同注意力的配置代码不同,请读者自行修改)
图中nc代表着你自己数据集标签的数量
5、进行训练
在YOLOv8源文件夹下,新建train.py,
from ultralytics import YOLO
if __name__ == '__main__':
# 加载模型
model = YOLO("yolov8-NAMAttention.yaml") # 从头开始构建新模型
#model = YOLO("yolov8x.pt") # 加载预训练模型(推荐用于训练)
# Use the model
results = model.train(data="data/detect_plane.yaml", epochs=500, batch=8, workers=1, close_mosaic=0, name='cfg') # 训练模型
# results = model.val() # 在验证集上评估模型性能
# results = model("https://ultralytics.com/images/bus.jpg") # 预测图像
# success = model.export(format="onnx") # 将模型导出为 ONNX 格式
其中model代表着你刚刚新建立的yaml文件名,也就是模型的名称,results代表着你数据集的配置文件,我的配置文件是上一篇博客讲的计挑赛的数据集配置文件。
最后,用命令行开始训练
python train.py
三、附言
注意力不一定会在所有数据集均有精度或者速度的提升,有些注意力只会在特定数据集有小幅度的数据提升,所以读者需要根据自己数据集的特点进行注意力的选择!