YOLO系列:改进YOLOv8——以添加Gam注意力模块为例

一、Gam注意力源码

import torch.nn as nn
import torch
 
class GAM_Attention(nn.Module):
    def __init__(self, in_channels,c2, rate=4):
        super(GAM_Attention, self).__init__()
 
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels)
        )
 
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(in_channels)
        )
 
    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()
        x = x * x_channel_att
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att
 
        return out
 
if __name__ == '__main__':
    x = torch.randn(1, 64, 20, 20)
    b, c, h, w = x.shape
    net = GAM_Attention(in_channels=c)
    y = net(x)
    print(y.size())

二、添加方法

此方法仅适用于新版YOLOv8,旧版YOLOv8添加方法略有不同

1、添加注意力源码

在ultralytics/nn/modules/conv.py文件内添加注意力源码

 2、注册并引用注意力

在ultralytics/nn/modules/__init__.py文件内,按下图标识的地方添加注意力名

第一处:在from .conv import()处最后,添加注意力名称

第二处:在__all__={}处最后,添加注意力名称

 3、调用注意力

在ultralytics/nn/tasks.py文件内,键盘点击CTRL+shift+F打开查找界面,搜索

def parse_model(d, ch, verbose=True):

在该函数下方有一堆的elif m in XXX,在某一个elif下方添加如下代码:

        elif m in {GAM_Attention}:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if not output
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, *args[1:]]

4、完成配置

在ultralytics/cfg/models/v8文件下,复制yolov8.yaml,并改成自己的名字,复制对应注意力的代码,这里我以Gam注意力为例(不同注意力的配置代码不同,请读者自行修改)

图中nc代表着你自己数据集标签的数量

5、进行训练

在YOLOv8源文件夹下,新建train.py,

from ultralytics import YOLO
if __name__ == '__main__':
    # 加载模型
    model = YOLO("yolov8-NAMAttention.yaml")  # 从头开始构建新模型
    #model = YOLO("yolov8x.pt")  # 加载预训练模型(推荐用于训练)

    # Use the model
    results = model.train(data="data/detect_plane.yaml", epochs=500, batch=8, workers=1, close_mosaic=0, name='cfg')  # 训练模型
    # results = model.val()  # 在验证集上评估模型性能
    # results = model("https://ultralytics.com/images/bus.jpg")  # 预测图像
    # success = model.export(format="onnx")  # 将模型导出为 ONNX 格式

其中model代表着你刚刚新建立的yaml文件名,也就是模型的名称,results代表着你数据集的配置文件,我的配置文件是上一篇博客讲的计挑赛的数据集配置文件。

最后,用命令行开始训练

python train.py

三、附言

注意力不一定会在所有数据集均有精度或者速度的提升,有些注意力只会在特定数据集有小幅度的数据提升,所以读者需要根据自己数据集的特点进行注意力的选择!

  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值