RTMDet训练流程(OpenMMLab AI实战营笔记7)

这次课程是使用rtmdet进行目标检测的课程

文章简介

我们先简单介绍一下rtmdet
paper:RTMDet: An Empirical Study of Designing Real-Time Object Detectors
GitHub: RTMDet

在本文中,我们的目标是设计一种高效的实时物体检测器,它超越了 YOLO 系列,并且可以轻松扩展到许多物体识别任务,例如实例分割和旋转物体检测。为了获得更高效的模型架构,我们探索了一种 在主干和颈部具有兼容能力的架构,该架构 由一个 由大核深度卷积组成的基本构建块 构建。我们在动态标签分配中 计算匹配成本时 进一步 引入软标签 以提高准确性。结合更好的训练技术,由此产生的名为 RTMDet 的目标检测器在 NVIDIA 3090 GPU 上以 300+ FPS 的速度在 COCO 上实现了 52.8% 的 AP,优于当前主流的工业检测器。RTMDet 针对各种应用场景实现了 tiny/small/medium/large/extra-large 模型大小的最佳 参数-精度权衡,并在实时实例分割和旋转目标检测方面获得了最新的性能。我们希望实验结果可以为设计用于许多目标识别任务的多功能实时目标检测器提供新的见解。

整体框架流程
这里,我们简单阐述一下RTMDet设计的基本思想,也就是希望设计一个BackBone和Neck部分可以进行融合的目标检测架构,从而让目标检测更轻量化。能突破yolo的速度。

MMDetection

  • MMDetection
    RTMDet的代码已在MMDetection复现并开源,我们这里使用RTMDet进行气球检测。
  • colab
    所有的代码也可以在我的colab中查看
    colab

代码

  • 安装基本工具 (MMDetection,MMYOLO)
%pip install -U "openmim==0.3.7"
!mim install "mmengine==0.7.1"
!mim install "mmcv==2.0.0"

!git clone -b tutorials https://github.com/open-mmlab/mmdetection.git
%cd mmdetection

%pip install -e .
  • 下载数据集
!wget https://download.openmmlab.com/mmyolo/data/balloon_dataset.zip
!unzip -q balloon_dataset.zip -d data
  • 数据集label转coco
    在这里下载的数据集并不是coco格式的,不能直接在MMDetection这个框架下进行训练。所以需要转换操作。
import os.path as osp
import mmengine, mmcv

def convert_balloon_to_coco(ann_file, out_file, image_prefix):
    data_infos = mmengine.load(ann_file)

    annotations = []
    images = []
    obj_count = 0
    for idx, v in enumerate(mmengine.track_iter_progress(data_infos.values())):
        filename = v['filename']
        img_path = osp.join(image_prefix, filename)
        height, width = mmcv.imread(img_path).shape[:2]

        images.append(dict(
            id=idx,
            file_name=filename,
            height=height,
            width=width))

        bboxes = []
        labels = []
        masks = []
        for _, obj in v['regions'].items():
            assert not obj['region_attributes']
            obj = obj['shape_attributes']
            px = obj['all_points_x']
            py = obj['all_points_y']
            poly = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]
            poly = [p for x in poly for p in x]

            x_min, y_min, x_max, y_max = (
                min(px), min(py), max(px), max(py))


            data_anno = dict(
                image_id=idx,
                id=obj_count,
                category_id=0,
                bbox=[x_min, y_min, x_max - x_min, y_max - y_min],
                area=(x_max - x_min) * (y_max - y_min),
                segmentation=[poly],
                iscrowd=0)
            annotations.append(data_anno)
            obj_count += 1

    coco_format_json = dict(
        images=images,
        annotations=annotations,
        categories=[{'id':0, 'name': 'balloon'}])
    mmengine.dump(coco_format_json, out_file)

convert_balloon_to_coco('data/balloon/train/via_region_data.json', 'data/balloon/train/annotation_coco.json', 'data/balloon/train')
convert_balloon_to_coco('data/balloon/val/via_region_data.json', 'data/balloon/val/annotation_coco.json', 'data/balloon/val')

这样,我们就把所有数据转换成coco格式。

  • 生成 Config
    在使用MMDetection或者所有MM进行训练时,我们都是需要创建一个Config文件用于保存参数,然后再进行训练。这样做的好处是。我们可以为每个训练任务创建一个Config文件,而不是每次修改之前的文件。可以理解成编程里,我们拥有一个父类。每次训练任务都生成新的子类进行继承。
# 当前路径位于 mmdetection/tutorials, 配置将写到 mmdetection/tutorials 路径下

config_balloon = """
_base_ = 'configs/rtmdet/rtmdet_tiny_8xb32-300e_coco.py'

data_root = 'data/balloon/'

metainfo = {
    'classes': ('balloon',),
    'palette': [
        (220, 20, 60),
    ]
}
num_classes = 1

max_epochs = 40
train_batch_size_per_gpu = 12
train_num_workers = 4

val_batch_size_per_gpu = 1
val_num_workers = 2

# RTMDet 训练过程分成 2 个 stage, 第二个 stage 会切换数据增强 pipeline
num_epochs_stage2 = 5

base_lr = 12 * 0.004 / (32*8)

# 采用 COCO 预训练权重
load_from = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'  # noqa

model = dict(
    backbone=dict(frozen_stages=4),
    bbox_head=dict(dict(num_classes=num_classes)))

train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    pin_memory=False,
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        ann_file='train/annotation_coco.json',
        data_prefix=dict(img='train/')))

val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    dataset=dict(
        metainfo=metainfo,
        data_root=data_root,
        ann_file='val/annotation_coco.json',
        data_prefix=dict(img='val/')))

test_dataloader = val_dataloader

# 默认的学习率调度器是 warmup 1000, 但是 cat 数据集太小了,需要修改 为 30 iter
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1.0e-5,
        by_epoch=False,
        begin=0,
        end=30),
    dict(
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,  # max_epoch 也改变了
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]
optim_wrapper = dict(optimizer=dict(lr=base_lr))

# 第二 stage 切换 pipeline 的 epoch 时刻也改变了
_base_.custom_hooks[1].switch_epoch = max_epochs - num_epochs_stage2

val_evaluator = dict(ann_file=data_root + 'val/annotation_coco.json')
test_evaluator = val_evaluator

# 一些打印设置修改
default_hooks = dict(
    checkpoint=dict(interval=10, max_keep_ckpts=2, save_best='auto'),  # 同时保存最好性能权重
    logger=dict(type='LoggerHook', interval=5))
train_cfg = dict(max_epochs=max_epochs, val_interval=10)
"""

with open('rtmdet_tiny_1xb12-40e_balloon.py', 'w') as f:
    f.write(config_balloon)
  • 训练
!python tools/train.py rtmdet_tiny_1xb12-40e_balloon.py
  • 测试
!python tools/test.py rtmdet_tiny_1xb12-40e_balloon.py work_dirs/rtmdet_tiny_1xb12-40e_balloon/best_coco/bbox_mAP_epoch_40.pth

这样就完成了训练。

可视化

更多可视化过程请参考我的colab文件中,其中还包含了热力图等可视化。

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用 openAI 人工智能区块链平台工具,数据采集,模型调用

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值