YOLOv8 添加 MobileViTv3 模块(附完整代码,免费)

目录

一、理由

二、方法

(1)导入MobileViTv3模块

(2)在ultralytics/nn/tasks.py的函数parse_model中修改

(3)在yaml配置文件中写入

(4)开始训练,先把其他梯度关闭,保留新加的模块的梯度。

代码已在GitHub上传,链接:yolov8_vit


一、理由

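        MobileViTv3 是一种为移动设备优化的轻量级视觉 Transformer 架构,它结合了卷积神经网络(CNN)和视觉 Transformer(ViT)的特点,用于构建适合移动端视觉任务的轻量级模型。与 MobileViTv2 相比,本文使用的 v3 模块主要有两处改动:一是将局部(卷积)特征与全局(Transformer)特征拼接后再用 1×1 卷积融合,而不是只使用全局特征;二是在模块输出上加入输入特征的残差连接。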

二、方法

(1)导入MobileViTv3模块

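在 ultralytics/nn 下创建 vit 文件夹,把下面的 MobileViTv3 模块保存为其中的 Vit.py(第(4)步通过 `from ultralytics.nn.vit.Vit import MbViTV3` 导入它)。同时把它所依赖的 MobileViTv2 实现 mobilevit_v2_block.py 等文件一并放入该文件夹。注意:代码中的 `from mobilevit_v2_block import ...` 是绝对导入。当该文件作为 ultralytics 包的一部分被导入时,需要改成相对导入 `from .mobilevit_v2_block import ...`,或者把该文件夹加入 `sys.path`。MobileViTv3 模块如下: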

import numpy as np
from torch import nn, Tensor
import math
import torch
from torch.nn import functional as F
from typing import Optional, Dict, Tuple, Union, Sequence
from mobilevit_v2_block import MobileViTBlockv2 as MbViTBkV2

class MbViTV3(MbViTBkV2):
    """MobileViTv3 block built on top of the MobileViTv2 block.

    Compared with the parent block, ``forward_spatial`` concatenates the
    local (convolutional) feature map with the folded global (attention)
    feature map, fuses the pair with a 1x1 convolution back to
    ``in_channels`` and then adds the block input as a residual.

    Args:
        in_channels: channels of the input feature map. It is also the output
            channel count, because the 1x1 fusion conv maps back to it.
        attn_unit_dim: width of the local/global branches. The fusion conv
            therefore expects ``2 * attn_unit_dim`` input channels.
        patch_h, patch_w, ffn_multiplier, n_attn_blocks, attn_dropout,
        dropout, ffn_dropout, conv_ksize, attn_norm_layer:
            NOTE(review): these are accepted but NOT forwarded to the parent
            constructor, which only receives ``in_channels`` and
            ``attn_unit_dim``, so the parent's own defaults apply. Confirm
            whether forwarding them is intended before relying on them.
        enable_coreml_compatible_fn: when True, use the parent's
            ``unfolding_coreml``/``folding_coreml`` helpers and register a
            non-persistent ``unfolding_weights`` buffer.
    """

    def __init__(
            self,
            in_channels: int,
            attn_unit_dim: int,
            patch_h: Optional[int] = 2,
            patch_w: Optional[int] = 2,
            ffn_multiplier: Optional[Union[Sequence[Union[int, float]], int, float]] = 2.0,
            n_attn_blocks: Optional[int] = 2,
            attn_dropout: Optional[float] = 0.0,
            dropout: Optional[float] = 0.0,
            ffn_dropout: Optional[float] = 0.0,
            conv_ksize: Optional[int] = 3,
            attn_norm_layer: Optional[str] = "layer_norm_2d",
            enable_coreml_compatible_fn: Optional[bool] = False,
    ) -> None:
        """Build the parent block, then replace its projection with the v3 fusion conv."""
        super().__init__(in_channels, attn_unit_dim)
        self.enable_coreml_compatible_fn = enable_coreml_compatible_fn
        if enable_coreml_compatible_fn:
            # persistent=False keeps these weights out of the model's state_dict.
            self.register_buffer(
                name="unfolding_weights",
                tensor=self._compute_unfolding_weights(),
                persistent=False,
            )
        # Fuse [global, local] features (attn_unit_dim channels each) back to in_channels.
        self.conv_proj = nn.Conv2d(
            in_channels=2 * attn_unit_dim,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
        )

    def forward_spatial(self, x: Tensor, *args, **kwargs) -> Tensor:
        """Run the MobileViTv3 spatial path on ``x`` and return the fused map.

        Extra positional/keyword arguments are accepted and ignored.
        """
        x = self.resize_input_if_needed(x)

        # Choose the unfold/fold implementation once instead of branching twice.
        use_coreml = self.enable_coreml_compatible_fn
        unfold = self.unfolding_coreml if use_coreml else self.unfolding_pytorch
        fold = self.folding_coreml if use_coreml else self.folding_pytorch

        # Local representation (convolutional branch).
        local_fm = self.local_rep(x)

        # Feature map -> patches -> global attention -> back to a feature map.
        patches, output_size = unfold(local_fm)
        patches = self.global_rep(patches)
        global_fm = fold(patches=patches, output_size=output_size)

        # MobileViTv3: fuse local + global features (not global only), then
        # add the block input as a residual.
        # NOTE(review): the residual uses x *after* resize_input_if_needed, so the
        # output size follows that helper; presumably unchanged when H and W
        # are already multiples of the patch size -- confirm for your inputs.
        fused = self.conv_proj(torch.cat([global_fm, local_fm], dim=1))
        return fused + x


if __name__ == '__main__':
    # Smoke test: run the block on a random feature map and report the output
    # shape and parameter count. MbViTV3 maps back to in_channels and adds the
    # input as a residual, so the output should keep 320 channels.
    # (FLOPs can additionally be measured with thop.profile(model, inputs=(x,))
    # when the optional `thop` package is installed.)
    model = MbViTV3(320, 160, enable_coreml_compatible_fn=False)
    x = torch.randn(1, 320, 44, 84)
    with torch.no_grad():
        out = model.forward_spatial(x)
    n_params = sum(p.numel() for p in model.parameters())
    print('output shape:', tuple(out.shape))
    print('params:', n_params)

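(2)在ultralytics/nn/tasks.py的函数parse_model中修改

注意:还需要在 tasks.py 顶部导入该模块(例如 `from ultralytics.nn.vit.Vit import MbViTV3`),否则 parse_model 中的 `MbViTV3` 这个名字无法解析。新增分支直接令 `c2 = args[0]`,不会按 width 系数缩放,因此 yaml 中填写的第一个参数必须是上一层实际输出的通道数。下面只给出需要新增的 `elif` 分支,`.......` 表示省略的原有代码。函数开头部分请以所安装 Ultralytics 版本中的实际代码为准:下面的开头读取的是 `depth_multiple`/`width_multiple`,而第(3)步的 yaml 使用的是 `scales`。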

def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    # Parse a YOLO model.yaml dictionary
    # Excerpt showing where the MbViTV3 branch goes; "......." marks unchanged upstream code that is omitted.
    # NOTE(review): this header reads d['depth_multiple'] / d['width_multiple'], but the YAML config used
    # with MbViTV3 defines `scales:` instead -- keep the header of your installed Ultralytics version and
    # only add the `elif` branch below.
    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
    nc, gd, gw, act = d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
    if act:
        # eval() on text from the YAML: only load trusted model configs.
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
    	.......
        # MbViTV3 must be imported at the top of tasks.py for this name to resolve.
    	elif m in {MbViTV3}:
            # Output channels = args[0] (in_channels): MbViTV3's 1x1 fusion conv maps back to
            # in_channels and adds the input as a residual. c2 is taken without width scaling,
            # so the YAML must give the real channel count of the previous layer.
            c2 = args[0]
        .......

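(3)在yaml配置文件中写入

将以下内容保存为 yolov8x_vit.yaml,与第(4)步中加载的文件名保持一致。Ultralytics 会根据文件名中的 x 选择 x 尺度(width=1.25)。此时第 15 层 C2f 的输出通道为 256×1.25=320,正好对应 MbViTV3 的参数 [320, 160],即 in_channels=320、attn_unit_dim=160。如果换用其他尺度,至少要把 in_channels 改为该尺度下第 15 层的实际输出通道数。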

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Variant with a MobileViTv3 block (MbViTV3, layer 16) inserted after the P3 C2f (layer 15).
# NOTE: MbViTV3 args are absolute channel counts (parse_model sets c2 = args[0] without width scaling).
# [320, 160] matches scale 'x' only: layer 15 outputs 256 * 1.25 = 320 channels there. Save as yolov8x_vit.yaml.

# Parameters
nc: 2  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8 backbone (shape comments: H*W*C for a 640x640 input, channels before width scaling)
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2        320*320*64
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4       160*160*128
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8       80*80*256
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16      40*40*512
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32     20*20*1024
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9              20*20*1024

# YOLOv8 head (MbViTV3 refines the P3 output before it feeds layer 17 and Detect)
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 10
  - [[-1, 6], 1, Concat, [1]]                  # 11
  - [-1, 3, C2f, [512]]                        # 12                 40*40*512

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 13
  - [[-1, 4], 1, Concat, [1]]                  # 14
  - [-1, 3, C2f, [256]]                        # 15 (P3/8-small)    80*80*256 (320 channels at scale x)
  - [-1, 1, MbViTV3, [320, 160]]               # 16 MobileViTv3 [in_channels, attn_unit_dim]; output keeps in_channels

  - [-1, 1, Conv, [256, 3, 2]]                 # 17
  - [[-1, 12], 1, Concat, [1]]                 # 18
  - [-1, 3, C2f, [512]]                        # 19 (P4/16-medium)  40*40*512

  - [-1, 1, Conv, [512, 3, 2]]                # 20
  - [[-1, 9], 1, Concat, [1]]                 # 21
  - [-1, 3, C2f, [1024]]                      # 22 (P5/32-large)  20*20*1024

  - [[16, 19, 22], 1, Detect, [nc]]           # 23 Detect(P3 from MbViTV3, P4, P5)

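(4)开始训练,先把其他梯度关闭,保留新加的模块的梯度。

注意:`model.train()` 会在训练器内部重新构建模型并载入权重,训练器还会把未列入 `freeze` 的参数重新设为可训练。因此,只在训练前手动设置 `requires_grad`(如下面的 `add_vit`)并不能真正冻结其他层。需要通过 `model.train(..., freeze=[...])` 传入要冻结的层号,此处为除第 16 层以外的所有层。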

# Training script: fine-tune yolov8x_vit while keeping only the MobileViTv3 layer trainable.
import os
from ultralytics import YOLO
import subprocess  # NOTE(review): not used anywhere in this script
# NOTE(review): MbViTV3 is not referenced directly below; presumably imported so a wrong module path
# fails early. parse_model in ultralytics/nn/tasks.py still needs its own import of MbViTV3.
from ultralytics.nn.vit.Vit import MbViTV3
# Presumably a workaround for the OpenMP "libiomp5 already initialized" (OMP Error #15) crash that
# occurs when multiple OpenMP runtimes are loaded -- confirm it is still needed in your environment.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

def add_vit(model, trainable_layers=('16',)):
    """Freeze every parameter of ``model.model`` except those of the given layers.

    Parameter names of the underlying network look like
    ``model.<layer_index>.<submodule>...``. The layer index is read from the
    second dot-separated field instead of a fixed character slice, so it also
    works for indices of any width.

    Args:
        model: YOLO wrapper whose ``.model`` attribute is the network.
        trainable_layers: layer indices (str or int) that stay trainable.
            Defaults to ``('16',)``, the MbViTV3 layer in yolov8x_vit.yaml.

    Returns:
        The same ``model`` object; ``requires_grad`` is modified in place.
        The names of the trainable parameters are printed.

    Note:
        Ultralytics' ``Model.train()`` builds a fresh network for the trainer,
        so these flags do not carry over into training by themselves. Also
        pass ``freeze=[...]`` to ``model.train``.
    """
    keep = {str(idx) for idx in trainable_layers}  # built once, O(1) lookups
    for name, param in model.model.named_parameters():
        parts = name.split('.')
        layer_idx = parts[1] if len(parts) > 1 else ''
        param.requires_grad = layer_idx in keep
    for name, param in model.model.named_parameters():
        if param.requires_grad:
            print(name)
    return model

def main():
    """Fine-tune yolov8x_vit so that only the MobileViTv3 layer (index 16) is trained.

    Ultralytics' ``Model.train()`` rebuilds the network inside the trainer,
    loads the weights into it, and re-enables gradients on every parameter
    not matched by ``freeze``. The ``requires_grad`` flags set by
    ``add_vit()`` therefore do not survive into training, so the layers to
    freeze are passed to the trainer explicitly.
    """
    # model = YOLO(r'ultralytics/cfg/models/v8/yolov8x.yaml').load('/root/autodl-tmp/yolov8x.pt')
    model = YOLO(r'yolov8x_vit.yaml').load('runs/detect/vit/weights/vit.pt')
    vit_layer = 16  # index of the MbViTV3 layer in yolov8x_vit.yaml
    model = add_vit(model)  # prints the parameters intended to stay trainable
    n_layers = len(model.model.model)  # layers of the underlying nn.Sequential
    freeze = [i for i in range(n_layers) if i != vit_layer]
    model.train(
        data="data.yaml",
        imgsz=640,
        epochs=50,
        batch=10,
        device=0,
        workers=0,
        freeze=freeze,
    )


if __name__ == '__main__':
    main()

————————————over————————————

  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值