【YOLO改进】换遍主干网络之CVPR2024 微软新作StarNet:超强轻量级Backbone(基于MMYOLO)

StarNet

论文链接:[2403.19967] Rewrite the Stars

github仓库:GitHub - ma-xu/Rewrite-the-Stars: [CVPR 2024] Rewrite the Stars

CVPR2024 Rewrite the Stars论文揭示了star operation(元素乘法)在无需加宽网络下,将输入映射到高维非线性特征空间的能力。基于此提出了StarNet,在紧凑的网络结构和较低的能耗下展示了令人印象深刻的性能和低延迟。

优势 (Advantages)

  1. 高维和非线性特征变换 (High-Dimensional and Non-Linear Feature Transformation)

    • StarNet通过星操作(star operation)实现高维和非线性特征空间的映射,而无需增加计算复杂度。与传统的内核技巧(kernel tricks)类似,星操作能够在低维输入中隐式获得高维特征​ (ar5iv)​。
    • 对于YOLO系列网络,这意味着在保持计算效率的同时,能够获得更丰富和表达力更强的特征表示,这对于目标检测任务中的精细特征捕获尤为重要。
  2. 高效网络设计 (Efficient Network Design)

    • StarNet通过星操作实现了高效的特征表示,无需复杂的网络设计和额外的计算开销。其独特的能力在于能够在低维空间中执行计算,但隐式地考虑极高维的特征​ (ar5iv)​。
    • 这使得StarNet可以作为YOLO系列网络的主干,提供高效的计算和更好的特征表示,有助于在资源受限的环境中实现更高的检测性能。
  3. 多层次隐式特征扩展 (Multi-Layer Implicit Feature Expansion)

    • 通过多层星操作,StarNet能够递归地增加隐式特征维度,接近无限维度。对于具有较大宽度和深度的网络,这种特性可以显著增强特征的表达能力​ (ar5iv)​。
    • 对于YOLO系列网络,这意味着可以通过适当的深度和宽度设计,显著提高特征提取的质量,从而提升目标检测的准确性。

解决的问题 (Problems Addressed)

  1. 计算复杂度与性能的平衡 (Balance Between Computational Complexity and Performance)

    • StarNet通过星操作在保持计算复杂度较低的同时,实现了高维特征空间的映射。这解决了传统高效网络设计中计算复杂度与性能之间的权衡问题​ (ar5iv)​。
    • YOLO系列网络需要在实时性和检测精度之间找到平衡,StarNet的高效特性正好契合这一需求。
  2. 特征表示的丰富性 (Richness of Feature Representation)

    • 传统卷积网络在特征表示的高维非线性变换上存在一定局限性,而StarNet通过星操作实现了更丰富的特征表示​ (ar5iv)​。
    • 在目标检测任务中,特别是对于小目标和复杂场景,丰富的特征表示能够显著提升检测效果,使得YOLO系列网络在这些场景中表现更佳。
  3. 简化网络设计 (Simplified Network Design)

    • StarNet通过星操作提供了一种简化网络设计的方法,无需复杂的特征融合和多分支设计就能实现高效的特征表示​ (ar5iv)​。
    • 对于YOLO系列网络,这意味着可以更容易地设计和实现高效的主干网络,降低设计和调试的复杂度。

在MMYOLO中将StarNet替换成yolov5的主干网络

1. 在上文提到的仓库中下载imagenet/starnet.py

2. 修改starnet.py中的forward函数,并且添加out_dices参数使其能够输出不同stage的特征向量

3. 将class StarNet注册并且在__init__()函数中进行修改

4. 修改配置文件,主要是调整YOLOv5 neck和head的输入输出通道数

修改后的starnet.py

"""
Implementation of Prof-of-Concept Network: StarNet.

We make StarNet as simple as possible [to show the key contribution of element-wise multiplication]:
    - like NO layer-scale in network design,
    - and NO EMA during training,
    - which would improve the performance further.

Created by: Xu Ma (Email: ma.xu1@northeastern.edu)
Modified Date: Mar/29/2024
"""
import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_
from typing import List, Sequence, Union

# from timm.models.registry import register_model
from mmyolo.registry import MODELS

model_urls = {
    "starnet_s1": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s1.pth.tar",
    "starnet_s2": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s2.pth.tar",
    "starnet_s3": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s3.pth.tar",
    "starnet_s4": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s4.pth.tar",
}


class ConvBN(torch.nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, with_bn=True):
        super().__init__()
        self.add_module('conv', torch.nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation, groups))
        if with_bn:
            self.add_module('bn', torch.nn.BatchNorm2d(out_planes))
            torch.nn.init.constant_(self.bn.weight, 1)
            torch.nn.init.constant_(self.bn.bias, 0)


class Block(nn.Module):
    def __init__(self, dim, mlp_ratio=3, drop_path=0.):
        super().__init__()
        self.dwconv = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=True)
        self.f1 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        self.f2 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        self.g = ConvBN(mlp_ratio * dim, dim, 1, with_bn=True)
        self.dwconv2 = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=False)
        self.act = nn.ReLU6()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x1, x2 = self.f1(x), self.f2(x)
        x = self.act(x1) * x2
        x = self.dwconv2(self.g(x))
        x = input + self.drop_path(x)
        return x

@MODELS.register_module()
class StarNet(nn.Module):
    def __init__(self, base_dim=32, out_indices: Sequence[int] = (0, 1, 2), depths=[3, 3, 12, 5], mlp_ratio=4,
                 drop_path_rate=0.0, num_classes=1000, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.in_channel = 32
        self.out_indices = out_indices
        self.depths = depths
        # stem layer
        self.stem = nn.Sequential(ConvBN(3, self.in_channel, kernel_size=3, stride=2, padding=1), nn.ReLU6())
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth
        # build stages
        self.stages = nn.ModuleList()
        cur = 0
        for i_layer in range(len(depths)):
            embed_dim = base_dim * 2 ** i_layer
            down_sampler = ConvBN(self.in_channel, embed_dim, 3, 2, 1)
            self.in_channel = embed_dim
            blocks = [Block(self.in_channel, mlp_ratio, dpr[cur + i]) for i in range(depths[i_layer])]
            cur += depths[i_layer]
            self.stages.append(nn.Sequential(down_sampler, *blocks))
        # head
        # self.norm = nn.BatchNorm2d(self.in_channel)
        # self.avgpool = nn.AdaptiveAvgPool2d(1)
        # self.head = nn.Linear(self.in_channel, num_classes)
        # self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear or nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm or nn.BatchNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.stem(x)
        ##记录stage的输出
        outs = []

        for i in range(len(self.depths)):
            x = self.stages[i](x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)


@MODELS.register_module()
def starnet_s1(pretrained=False, **kwargs):
    model = StarNet(24, (0, 1, 2), [2, 2, 8, 3], **kwargs)
    if pretrained:
        url = model_urls['starnet_s1']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
    return model


@MODELS.register_module()
def starnet_s2(pretrained=False, **kwargs):
    model = StarNet(32, (0, 1, 2), [1, 2, 6, 2], **kwargs)
    if pretrained:
        url = model_urls['starnet_s2']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
    return model


@MODELS.register_module()
def starnet_s3(pretrained=False, **kwargs):
    model = StarNet(32, (0, 1, 2), [2, 2, 8, 4], **kwargs)
    if pretrained:
        url = model_urls['starnet_s3']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
    return model


@MODELS.register_module()
def starnet_s4(pretrained=False, **kwargs):
    model = StarNet(32, (0, 1, 2), [3, 3, 12, 5], **kwargs)
    if pretrained:
        url = model_urls['starnet_s4']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
    return model


# very small networks #
@MODELS.register_module()
def starnet_s050(pretrained=False, **kwargs):
    return StarNet(16, (0, 1, 2), [1, 1, 3, 1], 3, **kwargs)


@MODELS.register_module()
def starnet_s100(pretrained=False, **kwargs):
    return StarNet(20, (0, 1, 2), [1, 2, 4, 1], 4, **kwargs)


@MODELS.register_module()
def starnet_s150(pretrained=False, **kwargs):
    return StarNet(24, (0, 1, 2), [1, 2, 4, 2], 3, **kwargs)

if __name__ == '__main__':
    model = StarNet()
    input_tensor = torch.randn(1, 3, 224, 224)
    outputs = model(input_tensor)

修改后的__init__.py

# Copyright (c) OpenMMLab. All rights reserved.
from .base_backbone import BaseBackbone
from .csp_darknet import YOLOv5CSPDarknet, YOLOv8CSPDarknet, YOLOXCSPDarknet
from .csp_resnet import PPYOLOECSPResNet
from .cspnext import CSPNeXt
from .efficient_rep import YOLOv6CSPBep, YOLOv6EfficientRep
from .yolov7_backbone import YOLOv7Backbone
from .starnet import StarNet
__all__ = [
    'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOv6CSPBep',
    'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet',
    'YOLOv8CSPDarknet','StarNet'
]

修改后的配置文件(以yolov5_s-v61_syncbn_8xb16-300e_coco.py为例子)

_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']

# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/'  # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/'  # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/'  # Prefix of val image path

num_classes = 80  # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True

# -----model related-----
# Basic size of multi-scale prior box
anchors = [
    [(10, 13), (16, 30), (33, 23)],  # P3/8
    [(30, 61), (62, 45), (59, 119)],  # P4/16
    [(116, 90), (156, 198), (373, 326)]  # P5/32
]

# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
base_lr = 0.01
max_epochs = 300  # Maximum training epochs

model_test_cfg = dict(
    # The config of multi-label for multi-class prediction.
    multi_label=True,
    # The number of boxes before NMS
    nms_pre=30000,
    score_thr=0.001,  # Threshold to filter out boxes.
    nms=dict(type='nms', iou_threshold=0.65),  # NMS type and threshold
    max_per_img=300)  # Max number of detections of each image

# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640)  # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2

# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
    type='BatchShapePolicy',
    batch_size=val_batch_size_per_gpu,
    img_size=img_scale[0],
    # The image scale of padding should be divided by pad_size_divisor
    size_divisor=32,
    # Additional paddings for pixel scale
    extra_pad_ratio=0.5)

# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
# Strides of multi-scale prior box
strides = [8, 16, 32]
num_det_layers = 3  # The number of model output scales
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)  # Normalization config

# -----train val related-----
affine_scale = 0.5  # YOLOv5RandomAffine scaling ratio
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4.  # Priori box matching threshold
# The obj loss weights of the three output layers
obj_level_weights = [4., 1., 0.4]
lr_factor = 0.01  # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)
'''
starnet_channel,base_dim,depths,mlp_ratio
s1:24,[48, 96, 192],[2, 2, 8, 3],4
s2:32,[64, 128, 256],[1, 2, 6, 2],4
s3:32,[64, 128, 256],[2, 2, 8, 4],4
s4:32,[64, 128, 256],[3, 3, 12, 5],4
starnet_s050:16,[32,64,128],[1, 1, 3, 1],3
starnet_s0100:20,[40, 80, 120],[1, 2, 4, 1],4
starnet_s150:24,[48, 96, 192],[1, 2, 4, 2],3
'''
starnet_channel=[48, 96, 192]
depths=[1, 2, 6, 2]
# ===============================Unmodified in most cases====================
model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        bgr_to_rgb=True),
    backbone=dict(
        ##s1
        type='StarNet',
        base_dim=24,
        out_indices=(0,1,2),
        depths=depths,
        mlp_ratio=4,
        num_classes=num_classes,
        # deepen_factor=deepen_factor,
        # widen_factor=widen_factor,
        # norm_cfg=norm_cfg,
        # act_cfg=dict(type='SiLU', inplace=True)

    ),

    neck=dict(
        type='YOLOv5PAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=starnet_channel,
        out_channels=starnet_channel,
        num_csp_blocks=3,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='YOLOv5Head',
        head_module=dict(
            type='YOLOv5HeadModule',
            num_classes=num_classes,
            in_channels=starnet_channel,
            widen_factor=widen_factor,
            featmap_strides=strides,
            num_base_priors=3),
        prior_generator=dict(
            type='mmdet.YOLOAnchorGenerator',
            base_sizes=anchors,
            strides=strides),
        # scaled based on number of detection layers
        loss_cls=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_cls_weight *
                        (num_classes / 80 * 3 / num_det_layers)),
        # 修改此处实现IoU损失函数的替换
        loss_bbox=dict(
            type='IoULoss',
            focal=True,
            iou_mode='ciou',
            bbox_format='xywh',
            eps=1e-7,
            reduction='mean',
            loss_weight=loss_bbox_weight * (3 / num_det_layers),
            return_iou=True),
        loss_obj=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_obj_weight *
                        ((img_scale[0] / 640) ** 2 * 3 / num_det_layers)),
        prior_match_thr=prior_match_thr,
        obj_level_weights=obj_level_weights),
    test_cfg=model_test_cfg)

albu_train_transforms = [
    dict(type='Blur', p=0.01),
    dict(type='MedianBlur', p=0.01),
    dict(type='ToGray', p=0.01),
    dict(type='CLAHE', p=0.01)
]

pre_transform = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='LoadAnnotations', with_bbox=True)
]

train_pipeline = [
    *pre_transform,
    dict(
        type='Mosaic',
        img_scale=img_scale,
        pad_val=114.0,
        pre_transform=pre_transform),
    dict(
        type='YOLOv5RandomAffine',
        max_rotate_degree=0.0,
        max_shear_degree=0.0,
        scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
        # img_scale is (width, height)
        border=(-img_scale[0] // 2, -img_scale[1] // 2),
        border_val=(114, 114, 114)),
    dict(
        type='mmdet.Albu',
        transforms=albu_train_transforms,
        bbox_params=dict(
            type='BboxParams',
            format='pascal_voc',
            label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
        keymap={
            'img': 'image',
            'gt_bboxes': 'bboxes'
        }),
    dict(type='YOLOv5HSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                   'flip_direction'))
]

train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=train_ann_file,
        data_prefix=dict(img=train_data_prefix),
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        pipeline=train_pipeline))

test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param'))
]

val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(img=val_data_prefix),
        ann_file=val_ann_file,
        pipeline=test_pipeline,
        batch_shapes_cfg=batch_shapes_cfg))

test_dataloader = val_dataloader

param_scheduler = None
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD',
        lr=base_lr,
        momentum=0.937,
        weight_decay=weight_decay,
        nesterov=True,
        batch_size_per_gpu=train_batch_size_per_gpu),
    constructor='YOLOv5OptimizerConstructor')

default_hooks = dict(
    param_scheduler=dict(
        type='YOLOv5ParamSchedulerHook',
        scheduler_type='linear',
        lr_factor=lr_factor,
        max_epochs=max_epochs),
    checkpoint=dict(
        type='CheckpointHook',
        interval=save_checkpoint_intervals,
        save_best='auto',
        max_keep_ckpts=max_keep_ckpts))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        strict_load=False,
        priority=49)
]

val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + val_ann_file,
    metric='bbox')
test_evaluator = val_evaluator

train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=save_checkpoint_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

  • 16
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
轻量化主干网络YOLO(You Only Look Once)是一种用于目标检测的神经网络模型。与传统的目标检测方法相比,YOLO可以实现实时高效的目标检测。 轻量化主干网络适用于移动设备和嵌入式设备等计算资源有限的场景。为了减少网络模型的参数数量和计算复杂度,轻量化主干网络采用了一系列优化策略。 首先,轻量化主干网络采用了深度可分离卷积层(Depthwise Separable Convolution)。深度可分离卷积层将卷积层分为深度卷积和逐点卷积两个步骤,分别处理通道间的信息和空间上的信息。这种方式有效减少了模型的参数数量和计算复杂度。 其次,轻量化主干网络使用了残差模块(Residual Module)。残差模块通过引入跳跃连接,将输入与输出相加,使得网络模型能够更好地学习残差信息。这种结构可以提升网络的性能,并减少网络的参数数量。 此外,轻量化主干网络还使用了空间金字塔池化模块(Spatial Pyramid Pooling)。空间金字塔池化模块可以从不同尺度上提取特征,具有多尺度感受野,在目标检测任务中起到了关键作用。 总体来说,轻量化主干网络采用了深度可分离卷积、残差模块和空间金字塔池化等技术,以减少网络的参数数量和计算复杂度,同时保持高准确率和实时的目标检测能力。它在移动设备和嵌入式设备等场景中具有较好的应用前景。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值