Mamba-Yolo:基于Mamba架构的yolov8目标检测模型

动机:

基于 CNN 和 Transformer 的模型各有局限性。CNN 在捕获长距离信息上存在局部感受野限制,导致在某些情况下难以有效捕获长距离信息,可能导致分割等任务的结果不佳。另一方面,Transformer 在全局建模方面表现出色,能够有效捕获长距离依赖关系,但自注意力机制在处理图像尺寸较大时的复杂度较高,特别是在处理超高清图像检测以及小目标检测等任务时可能面临挑战。

CNN 主要局限性:

    局部感受野限制:CNN 的卷积操作在每一层只能感知局部区域的特征,难以捕获长距离依赖关系。
    参数共享:CNN 中参数共享的特性可能限制其在处理某些复杂模式和全局信息时的表现。

Transformer 主要局限性:

    自注意力机制复杂度:Transformer 中的自注意力机制在处理大规模图像时需要高计算复杂度和显存消耗。
    缺乏局部信息:Transformer 更注重全局关系,可能在一些需要局部信息的任务中表现不佳。

因此,为了克服CNN和Transformer的局限性,SSMs(如Mamba)通过建立远距离依赖关系并保持线性复杂度,展现出在各种任务中的潜力。本文首次提出了 mamba-Yolov8,这是一种将Mamba结合到Yolov8架构中的方法,旨在展示其在目标检测任务中的潜力。通过结合Mamba的优势,mamba-Yolov8旨在改善长距离信息捕获和全局建模能力,以提高目标检测任务的性能和效果。这种结合可能有助于克服传统CNN和Transformer在某些任务中的局限性,为目标检测等任务带来新的发展和进步。

若有想进行魔改、发文章的小伙伴,可在此基础上进行调整、以适配个人发文章的需求。

下图为打印出的结构

其中ultralytics.nn.Addmodules.mamba.MambaLayer 为mamba结构

核心:VSSblock(上图中的MambaLayer)

mamba-yolov8的核心模块是来自 VMamba 的 VSS 块,如图下图所示。

对于经过层归一化后的输入,模型分为两个分支处理:第一个分支经过线性层和激活函数处理,第二个分支经过线性层、深度可分离卷积和激活函数处理,然后进入2D-Selective-Scan(SS2D)。处理后的特征再次归一化,并与第一个分支的输出进行逐元素乘积合并,随后经过一个线性层混合特征,再与残差连接相加形成VSS块的输出。默认情况下,使用激活函数SiLU。

主要还是在 SS2D 这个新的模块,大家可以参考下下面的示意图。

SS2D模块通过扫描展开操作将输入图像在四个方向上展开成序列,然后通过S6块提取特征,以确保全面扫描信息并捕获多样特征。随后,扫描合并操作对四个方向的序列进行求和合并,将输出图像恢复为输入大小。S6块是基于Mamba模块的进一步发展,在S4基础上引入选择机制,有助于保留相关信息并过滤无关信息。

YoloV8改进步骤

1.在该ultralytics/nn下创建Addmodules文件夹,并在下面新建mamba.py文件

2.在mamba.py文件中写入。(注:全部代码私信博主获取,将博主所给代码文件mamba.py,放置在ultralytics/nn/Addmodules/目录结构下)

    class MambaLayer(nn.Module):
        def __init__(self, dim, d_state=16, d_conv=4, expand=2):
            super().__init__()
            self.dim = dim
            self.norm = nn.LayerNorm(dim)
            self.mamba = Mamba(
                d_model=dim,  # Model dimension d_model
                d_state=d_state,  # SSM state expansion factor
                d_conv=d_conv,  # Local convolution width
                expand=expand,  # Block expansion factor
                bimamba_type="v2",
            )
     
        def forward(self, x):
            B, C = x.shape[:2]
     
     
            assert C == self.dim
            n_tokens = x.shape[2:].numel()
            img_dims = x.shape[2:]
            x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
            x_norm = self.norm(x_flat)
     
            # x_norm = x_norm.to('cuda')
     
            x_mamba = self.mamba(x_norm)
     
            out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims)
            #out = out.to(x.device)
            return out

3.在ultralytics/nn/Addmodules/__init__.py文件中写入

from .mamba import *

如下图(注:全部代码私信博主获取,将博主所给代码文件__init__.py,放置在ultralytics/nn/Addmodules/目录结构下)

4. 在ultralytics/nn/tasks.py中导入MambaLayer

from .Addmodules import *

5.在在ultralytics/nn/tasks.py中加入MambaLayer模块

6.在ultralytics/nn/tasks.py的class DetectionModel(BaseModel)类中进行如下修改

    class DetectionModel(BaseModel):
        """YOLOv8 detection model."""
        def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True):  # model, input channels, number of classes
            """Initialize the YOLOv8 detection model with the given config and parameters."""
            super().__init__()
            self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg)  # cfg dict
     
            # Define model
            ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
            if nc and nc != self.yaml['nc']:
                LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
                self.yaml['nc'] = nc  # override YAML value
            self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose)  # model, savelist
            self.names = {i: f'{i}' for i in range(self.yaml['nc'])}  # default names dict
            self.inplace = self.yaml.get('inplace', True)
            # Build strides
            m = self.model[-1]  # Detect()
            if isinstance(m, (Detect, Segment, Pose)):
                s = 256  # 2x min stride
                m.inplace = self.inplace
                forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose)) else self.forward(x)
                # -------原始---------
                #m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward ,模型是通过一次前向传播的方式进行输入、输出比来知道步长缩放比
                #self.stride = m.stride
                # --------------------
                #--基于mamba的改进
                self.stride=torch.tensor([8., 16., 32.])
                m.stride=self.stride
                #----------------------
                m.bias_init()  # only run once
            else:
                self.stride = torch.Tensor([32])  # default stride for i.e. RTDETR
     
            # Init weights, biases
            initialize_weights(self)
            if verbose:
                self.info()
                LOGGER.info('')

7. 在ultralytics/cfg/models/v8/mamba.yaml中配置网络模型结构文件

  

  # Ultralytics YOLO 🚀, AGPL-3.0 license
    # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
     
    # Parameters
    nc: 2  # number of classes
    scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
      # [depth, width, max_channels]
      n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
      s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
      d: [0.67, 0.50, 768]   #YOLOv8s summary: 295 layers, 11716214 parameters, 11716189 gradients,  36.2 GFLOPs
      m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
      l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
      x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
     
    # YOLOv8.0n backbone
    backbone:
      # [from, repeats, module, args]
      - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2      # 0.  320
      - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4     # 1.  160
      - [-1, 3, MambaLayer, [128]]                # 2.  160
      - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8     # 3.  80
      - [-1, 6, MambaLayer, [256]]                # 4.  80
      - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16    # 5.  40
      - [-1, 6, MambaLayer, [512]]                # 6.  40
      - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32   # 7.  20
      - [-1, 3, MambaLayer, [1024]]               # 8.  20
      - [-1, 1, SPPF, [1024, 5]]  # 9            # 9.  20
    # YOLOv8.0n head
    head:
      - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
      - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
      - [-1, 3, C2f, [512]]  # 12
     
      - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
      - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
      - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
     
      - [-1, 1, Conv, [256, 3, 2]]
      - [[-1, 12], 1, Concat, [1]]  # cat head P4
      - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
     
      - [-1, 1, Conv, [512, 3, 2]]
      - [[-1, 9], 1, Concat, [1]]  # cat head P5
      - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
     
      - [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

8.撰写训练train.py文件开启训练

   

from ultralytics import YOLO
     
    model = YOLO("mamba.yaml")
     
    model.train(data='datasets.yaml',epochs=300,device="0",batch=40,imgsz=640,amp=False)

  • 29
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 22
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 22
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值