(1) Your First Training and Inference with mmdetection


GitHub provides a getting-started document, but newcomers still tend to run into all kinds of problems. Below is the detailed process from my first attempt, for reference.

1. Installing mmdetection

The mmdetection installation steps are described at https://github.com/open-mmlab/mmdetection/blob/master/docs/zh_cn/get_started.md. Installing the latest version is recommended.
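Before continuing, it is worth confirming that the installation actually works. A minimal check (just a sketch; it only assumes mmcv and mmdet were installed as described above):

import mmcv
import mmdet

# If both imports succeed, the basic installation is in place.
print('mmcv version:', mmcv.__version__)
print('mmdet version:', mmdet.__version__)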

2. Downloading the dataset

wget https://download.openmmlab.com/mmdetection/data/kitti_tiny.zip
unzip kitti_tiny.zip -d your_dir
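After extraction, the directory should contain the files referenced throughout the rest of this post: train.txt, val.txt, training/image_2 and training/label_2. A quick sanity check, assuming the archive was extracted into kitti_tiny/ under the current directory:

import os.path as osp

data_root = 'kitti_tiny/'  # adjust if you extracted to a different location
for rel_path in ['train.txt', 'val.txt', 'training/image_2', 'training/label_2']:
    status = 'found' if osp.exists(osp.join(data_root, rel_path)) else 'missing'
    print(rel_path, '->', status)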

3. Registering the dataset

import os.path as osp
import mmcv
import numpy as np
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset

@DATASETS.register_module()
class KittiTinyDataset(CustomDataset):

    CLASSES = ('Car', 'Pedestrian', 'Cyclist')

    def load_annotations(self, ann_file):
        cat2label = {k: i for i, k in enumerate(self.CLASSES)}
        # load image list from file
        image_list = mmcv.list_from_file(self.ann_file)

        data_infos = []
        # convert annotations to middle format
        for image_id in image_list:
            filename = f'{self.img_prefix}/{image_id}.jpeg'
            image = mmcv.imread(filename)
            height, width = image.shape[:2]

            data_info = dict(filename=f'{image_id}.jpeg', width=width, height=height)

            # load annotations
            label_prefix = self.img_prefix.replace('image_2', 'label_2')
            lines = mmcv.list_from_file(osp.join(label_prefix, f'{image_id}.txt'))

            content = [line.strip().split(' ') for line in lines]
            bbox_names = [x[0] for x in content]
            bboxes = [[float(info) for info in x[4:8]] for x in content]

            gt_bboxes = []
            gt_labels = []
            gt_bboxes_ignore = []
            gt_labels_ignore = []

            # filter 'DontCare'
            for bbox_name, bbox in zip(bbox_names, bboxes):
                if bbox_name in cat2label:
                    gt_labels.append(cat2label[bbox_name])
                    gt_bboxes.append(bbox)
                else:
                    gt_labels_ignore.append(-1)
                    gt_bboxes_ignore.append(bbox)

            data_anno = dict(
                bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
                labels=np.array(gt_labels, dtype=np.int64),
                bboxes_ignore=np.array(gt_bboxes_ignore,
                                       dtype=np.float32).reshape(-1, 4),
                labels_ignore=np.array(gt_labels_ignore, dtype=np.int64))

            data_info.update(ann=data_anno)
            data_infos.append(data_info)

        return data_infos
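For reference, each line of a KITTI label file is space-separated: column 0 is the class name and columns 4-7 are the 2D box (x1, y1, x2, y2), which is why the loader above reads x[0] and x[4:8]. A tiny illustration with a made-up line (the numeric values are placeholders, not real annotations):

# Placeholder KITTI-style label line, for illustration only.
sample = 'Car 0.00 0 -1.58 587.01 173.33 614.12 200.12 1.65 1.67 3.64 -0.65 1.71 46.70 -1.59'
fields = sample.split(' ')
print(fields[0])                        # class name: 'Car'
print([float(v) for v in fields[4:8]])  # 2D bbox [x1, y1, x2, y2]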

4. Modifying the config file

Download the checkpoint 'mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth' from the model zoo and place it in the checkpoints/ directory, matching the load_from path set below.

from mmcv import Config
cfg = Config.fromfile('./configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco.py')

from mmdet.apis import set_random_seed

# Modify dataset type and path
cfg.dataset_type = 'KittiTinyDataset'
cfg.data_root = 'kitti_tiny/'

cfg.data.test.type = 'KittiTinyDataset'
cfg.data.test.data_root = 'kitti_tiny/'
cfg.data.test.ann_file = 'train.txt'
cfg.data.test.img_prefix = 'training/image_2'

cfg.data.train.type = 'KittiTinyDataset'
cfg.data.train.data_root = 'kitti_tiny/'
cfg.data.train.ann_file = 'train.txt'
cfg.data.train.img_prefix = 'training/image_2'

cfg.data.val.type = 'KittiTinyDataset'
cfg.data.val.data_root = 'kitti_tiny/'
cfg.data.val.ann_file = 'val.txt'
cfg.data.val.img_prefix = 'training/image_2'

# modify num classes of the model in box head
cfg.model.roi_head.bbox_head.num_classes = 3
# We can still use the pre-trained Mask RCNN model though we do not need to
# use the mask branch
cfg.load_from = 'checkpoints/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth'
cfg.work_dir = './tutorial_exps'

# The original learning rate (LR) is set for 8-GPU training.
# We divide it by 8 since we only use one GPU.
cfg.optimizer.lr = 0.02 / 8
cfg.lr_config.warmup = None
cfg.log_config.interval = 10

# Change the evaluation metric since we use customized dataset.
cfg.evaluation.metric = 'mAP'
# We can set the evaluation interval to reduce the evaluation times
cfg.evaluation.interval = 12
# We can set the checkpoint saving interval to reduce the storage cost
cfg.checkpoint_config.interval = 12

# Set seed thus the results are more reproducible
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
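At this point it can be useful to inspect the fully merged configuration and keep a copy alongside the training outputs. An optional sketch using mmcv's Config.pretty_text and Config.dump (the dumped file name is an arbitrary choice):

import os.path as osp
import mmcv

# Print the fully resolved config and save a copy into the work directory.
print(cfg.pretty_text)
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
cfg.dump(osp.join(cfg.work_dir, 'kitti_tiny_faster_rcnn.py'))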

5. Your first mmdetection training run

from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector
# Build dataset
datasets = [build_dataset(cfg.data.train)]
# Build the detector
model = build_detector(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_detector(model, datasets, cfg, distributed=False, validate=True)
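With the 1x schedule (12 epochs) and checkpoint_config.interval = 12, training writes its logs and a single checkpoint, tutorial_exps/epoch_12.pth, into cfg.work_dir. That checkpoint is what the standalone inference script in the next section loads.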

6. Your first mmdetection inference

from mmdet.apis import inference_detector, show_result_pyplot

img = mmcv.imread('kitti_tiny/training/image_2/000068.jpeg')
model.cfg = cfg
result = inference_detector(model, img)
show_result_pyplot(model, img, result)

All of the Python code above should be placed in one file and executed together. Once training has finished, the following script can be used on its own for inference.

import os
import mmcv
from mmdet.apis import inference_detector, init_detector, show_result_pyplot
from mmcv import Config
cfg = Config.fromfile('./configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco.py')
checkpoint_file = 'tutorial_exps/epoch_12.pth'
cfg.model.roi_head.bbox_head.num_classes = 3
# build the model from the config file and the checkpoint file
model = init_detector(cfg, checkpoint_file, device='cuda:0')

file_path="/home/mby/mmdetection/data/kitti_tiny/training/image_2/"
for root, dirs, files in os.walk(file_path, topdown=False):
    for image_id in files:
        filename = file_path + image_id
        image = mmcv.imread(filename)
        result = inference_detector(model, image)
        show_result_pyplot(model, image, result)
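show_result_pyplot opens a matplotlib window, which is inconvenient on a headless server. As an alternative, the detector's show_result method can write the visualizations to disk instead. A sketch of this variant (the out/ directory name and the 0.3 score threshold are arbitrary choices of mine):

import os
import mmcv
from mmcv import Config
from mmdet.apis import inference_detector, init_detector

cfg = Config.fromfile('./configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco.py')
cfg.model.roi_head.bbox_head.num_classes = 3
model = init_detector(cfg, 'tutorial_exps/epoch_12.pth', device='cuda:0')

file_path = '/home/mby/mmdetection/data/kitti_tiny/training/image_2/'
out_dir = 'out'  # arbitrary output directory for the rendered images
mmcv.mkdir_or_exist(out_dir)

for root, dirs, files in os.walk(file_path, topdown=False):
    for image_id in files:
        image = mmcv.imread(file_path + image_id)
        result = inference_detector(model, image)
        # Draw boxes above a 0.3 score threshold and save to disk instead of displaying.
        model.show_result(image, result, score_thr=0.3,
                          out_file=os.path.join(out_dir, image_id))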