爆改YOLOv8 | yolov8添加GAM注意力机制

1,本文介绍

GAM(Global Attention Mechanism)旨在改进传统注意力机制的不足,特别是在通道和空间维度上的信息保留问题。它通过顺序的通道-空间注意力机制来解决这些问题。以下是GAM的关键设计和实现细节:

  1. 通道注意力子模块

    • 3D排列:使用3D排列来在三个维度上保留信息,这种方法有助于捕捉更多维度的特征。
    • 两层MLP:通过一个两层的多层感知机(MLP)增强跨维度的通道-空间依赖性,提升了模型对复杂特征的学习能力。
  2. 空间注意力子模块

    • 两个卷积层:采用两个卷积层融合空间信息,增强空间特征的学习,而不是使用最大池化操作,避免了可能导致信息损失的操作。
    • 分组卷积与通道混洗:通过分组卷积和通道混洗,GAM在ResNet50中避免了显著的参数增加,这有助于减少计算开销和内存占用。
  3. 性能提升

    • 在不同网络架构上的应用:GAM在多种神经网络架构上都展示了稳定的性能提升,尤其在ResNet18上,GAM在参数更少的情况下展现了比ABN(Adaptive Bottleneck Network)更好的性能和效率。

GAM通过这些设计增强了对全局信息的捕捉能力,并在保持高效性的同时,显著提高了模型的表现。

以下为GAM模型结构图

关于GAM的详细介绍可以看论文:https://arxiv.org/pdf/2112.05561v1.pdf

本文将讲解如何将GAM融合进yolov8

话不多说,上代码!

2,将GAM融合进YOLOv8

2.1 步骤一

首先找到如下的目录'ultralytics/nn/modules',然后在这个目录下创建一个attention.py文件,文件名字可以根据你自己的习惯起,然后将GAM的核心代码复制进去。

# gam核心代码
import torch
import torch.nn as nn
 
'''
https://arxiv.org/abs/2112.05561
'''
__all__ = (
    "GAM",
)
class GAM(nn.Module):
    def __init__(self, in_channels, rate=4):
        super().__init__()
        out_channels = in_channels
        in_channels = int(in_channels)
        out_channels = int(out_channels)
        inchannel_rate = int(in_channels/rate)
 
 
        self.linear1 = nn.Linear(in_channels, inchannel_rate)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(inchannel_rate, in_channels)
        
 
        self.conv1=nn.Conv2d(in_channels, inchannel_rate,kernel_size=7,padding=3,padding_mode='replicate')
 
        self.conv2=nn.Conv2d(inchannel_rate, out_channels,kernel_size=7,padding=3,padding_mode='replicate')
 
        self.norm1 = nn.BatchNorm2d(inchannel_rate)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self,x):
        b, c, h, w = x.shape
        # B,C,H,W ==> B,H*W,C
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        
        # B,H*W,C ==> B,H,W,C
        x_att_permute = self.linear2(self.relu(self.linear1(x_permute))).view(b, h, w, c)
 
        # B,H,W,C ==> B,C,H,W
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
 
        x = x * x_channel_att
 
        x_spatial_att = self.relu(self.norm1(self.conv1(x)))
        x_spatial_att = self.sigmoid(self.norm2(self.conv2(x_spatial_att)))
        
        out = x * x_spatial_att
 
        return out
 
if __name__ == '__main__':
    img = torch.rand(1,64,32,48)
    b, c, h, w = img.shape
    net = GAM(in_channels=c, out_channels=c)
    output = net(img)
    print(output.shape)

2.2 步骤二

首先找到如下的目录'ultralytics/nn/modules',然后在这个目录下找到init文件,在init中添加如下代码.

from .attention import (
    GAM,
)

同时在init.py中的如下位置添加GAM

2.3 步骤三

在task.py中导入GAM

 2.3 步骤四

在task.py中添加如下代码.

到此注册成功,复制后面的yaml文件直接运行即可

yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
 
# Parameters
nc: 80  # number of classes
scales:  # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
 
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]   # 8
  - [-1, 1, GAM, []]  # 9
  - [-1, 1, SPPF, [1024, 5]]  # 10
 
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13
 
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)
 
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)
 
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)
 
  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

# 关于GAM添加的位置还可以放在颈部,针对不同数据集位置不同,效果不同

不知不觉已经看完了哦,动动小手留个点赞吧--_--

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值