【MMDetection系列 - 2】MMDetection使用自定义数据集训练、测试已有模型


前言

本文介绍如何用自定义数据集来对MMDetection中给定的模型进行训练及测试。

本文采用的模型是:deformable_detr
数据集为: cat 数据集


一、数据集准备

在终端输入以下指令来下载并解压cat数据集:

rm -rf cat_dataset*
wget https://download.openmmlab.com/mmyolo/data/cat_dataset.zip
unzip cat_dataset.zip -d cat_dataset && rm cat_dataset.zip 

文件目录结构为:

mmdetection
├── mmdet
├── tools
├── configs
├── cat_dataset
│ ├── annotations (其中包括3个json文件)
│ ├── images(包含144张图片)
│ ├── labels(图片的json文件)
│ ├── labelsclass_with_id.txt

二、编写配置文件

1.创建配置文件

在mmdetection根目录下创建一个文件夹try_demo,在文件夹下建立一个my_demo_cat.py文件,文件内容如下

# my_demo_cat: 使用deformable-detr & cat数据集

_base_ = [
    '../configs/deformable_detr/deformable-detr-refine-twostage_r50_16xb2-50e_coco.py'
]

model = dict(
    backbone=dict(
        init_cfg=None)		# 不直接从官网下载预训练模型,使用我自己下载好的预训练模型
)


data_root = 'cat_dataset/' #数据集路径前缀


metainfo = {
    # 类别名,注意 classes 需要是一个 tuple,因此即使是单类,后面也需要加逗号。
    'classes': ('cat',),  # 类别
    'palette': [
        (220, 20, 60),
    ]
}

# 数据集中的类别数
num_classes = 1

model = dict(
    # 考虑到数据集太小,且训练时间很短,我们把 backbone 完全固定
    # 用户自己的数据集可能需要解冻 backbone
    backbone=dict(frozen_stages=4),
    # 修改head中的num_classes,以匹配数据集中的类别数目 (默认的num_classes在configs/rtmdet/rtmdet_l_8xb32-300e_coco.py里)
    bbox_head=dict(dict(num_classes=num_classes)))


# 数据集不同,dataset 输入参数也不一样
train_dataloader = dict(  #训练dataloader 配置
    pin_memory=False,
    dataset=dict( # 训练数据集的配置
        data_root=data_root, # 数据的根路径
        metainfo=metainfo,
        ann_file='annotations/trainval.json', # 标注文件路径
        data_prefix=dict(img='images/'))) # 图片路径前缀

val_dataloader = dict(  # 验证 dataloader 配置
    dataset=dict(
        metainfo=metainfo,
        data_root=data_root,
        ann_file='annotations/test.json',
        data_prefix=dict(img='images/')))

test_dataloader = val_dataloader



# 修改评价指标相关配置
val_evaluator = dict(ann_file=data_root + 'annotations/test.json')
test_evaluator = val_evaluator

load_from = 'checkpoints/deformable-detr-refine-twostage_r50_16xb2-50e_coco_20221021_184714-acc8a5ff.pth'	# 自己下载的预训练模型路径

default_hooks = dict(
    checkpoint=dict(interval=10, max_keep_ckpts=2, save_best='auto'),  # 同时保存最好性能权重
    logger=dict(type='LoggerHook', interval=5))
train_cfg = dict(max_epochs=20, val_interval=10)

需要注意的是:其中预训练权重deformable-detr-refine-twostage_r50_16xb2-50e_coco_20221021_184714-acc8a5ff.pth需自己下载并放在mmdetection/checkpoints/路径下,下载命令为:

mim download mmdet --config deformable-detr-refine-twostage_r50_16xb2-50e_coco  --dest .

2.模型训练

在终端使用以下命令进行模型训练:

python tools/train.py try_demo/my_demo_cat.py 

即根据配置文件的信息调用train.py进行模型训练。

运行完成后,会在当前路径下生成 work_dirs/my_demo_cat 路径,内部存放了 log 和权重文件。

2.模型测试

在终端使用以下命令进行测试:

python tools/test.py try_demo/my_demo_cat.py work_dirs/my_demo_cat/best_coco_bbox_mAP_epoch_20.pth  --show-dir results

设置 --show-dir 可以将测试图片的真实值和预测值保存下来。

运行后,可以在work_dirs/my_demo_cat/当前时间戳/results/下看到生成的图片。
在这里插入图片描述展示其中一张图片,可以看到一张图由真实图片+预测图片构成。
在这里插入图片描述

3.单张图片推理

对单张图片进行推理,可以直接使用mmdetection/demo/image_demo.py脚本。(与第一篇中介绍的运行mmdetection给的demo图片一样的步骤)

在终端输入:

python demo/image_demo.py \
     cat_dataset/images/IMG_20220906_143153.jpg \
     try_demo/my_demo_cat.py \
     --weights work_dirs/my_demo_cat/best_coco_bbox_mAP_epoch_20.pth --show

可以看到单张图片的推理结果:
在这里插入图片描述

4.评价指标可视化

可以通过在终端输入以下指令,来可视化评价指标

首先需要确保安装了seaborn库

python tools/analysis_tools/analyze_logs.py plot_curve work_dirs/my_demo_cat/当前时间戳/vis_data/当前时间戳.json --keys loss loss_cls

找到你的json文件,–keys后面的参数也可自行指定(根据json中有的指标来,这里我展示了loss和loss_cls),得到的可视化结果如下:
在这里插入图片描述

总结

以上就是如何使用MMDetection中的模型与自定义的数据集进行训练的代码,其中模型和数据集都可以根据自己的需要来进行替换,以上只是一个示例,展示了一下基本流程。

参考文献:
[1] MMDetection官方文档之“在标准数据集上训练预定义的模型”
[2] MMDetection实战–基于RTMDet在Cats和10类饮料数据集上的目标检测与MMYoLo Grad-Based CAM可视化
[3] 使用mmdetection训练自己的coco数据集(免费分享自制数据集文件)
[4] mmdetection训练自己的COCO数据集

  • 16
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值