GitHub 上有 MMDetection 的入门文档,但新手在实际使用时还是会遇到各种问题。下面是我第一次使用的详细过程,供大家参考。
一、mmdetection安装
mmdetection安装过程可以参考链接:https://github.com/open-mmlab/mmdetection/blob/master/docs/zh_cn/get_started.md 建议安装最新版本
二、下载数据集
# Download the tiny KITTI sample archive.
wget https://download.openmmlab.com/mmdetection/data/kitti_tiny.zip
# Extract into your_dir. `-d` selects the extraction directory; a `>` redirect
# would only write unzip's log output into a file named your_dir.
unzip kitti_tiny.zip -d your_dir
三、注册数据集
import os.path as osp
import mmcv
import numpy as np
from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset
@DATASETS.register_module()
class KittiTinyDataset(CustomDataset):
    """Dataset wrapper that converts KITTI-style label files to info dicts.

    Images are read from ``img_prefix`` as ``<image_id>.jpeg``; the matching
    label file ``<image_id>.txt`` is looked up in the directory obtained by
    replacing ``'image_2'`` with ``'label_2'`` in ``img_prefix``.
    """

    CLASSES = ('Car', 'Pedestrian', 'Cyclist')

    def load_annotations(self, ann_file):
        """Load the image list and per-image annotations.

        Args:
            ann_file (str): Path of a text file with one image id per line.

        Returns:
            list[dict]: One dict per image with the keys ``filename``,
            ``width``, ``height`` and ``ann``. ``ann`` holds ``bboxes`` /
            ``labels`` for objects whose name is in ``CLASSES`` and
            ``bboxes_ignore`` / ``labels_ignore`` for everything else.
        """
        cat2label = {k: i for i, k in enumerate(self.CLASSES)}
        # Use the argument rather than self.ann_file so the method honours
        # whatever path the caller actually passes in.
        image_list = mmcv.list_from_file(ann_file)
        # The label directory does not depend on the image, so compute it once.
        label_prefix = self.img_prefix.replace('image_2', 'label_2')

        data_infos = []
        # Convert annotations to the middle format.
        for image_id in image_list:
            filename = f'{self.img_prefix}/{image_id}.jpeg'
            # The image is decoded only to obtain its size.
            image = mmcv.imread(filename)
            height, width = image.shape[:2]
            data_info = dict(
                filename=f'{image_id}.jpeg', width=width, height=height)

            # Load annotations; blank lines are skipped so they cannot
            # produce an empty (malformed) box row.
            lines = mmcv.list_from_file(
                osp.join(label_prefix, f'{image_id}.txt'))
            content = [
                line.strip().split(' ') for line in lines if line.strip()
            ]
            bbox_names = [x[0] for x in content]
            # Fields 4..7 of each label line hold the box
            # (presumably x1, y1, x2, y2 as in the KITTI label format).
            bboxes = [[float(info) for info in x[4:8]] for x in content]

            gt_bboxes = []
            gt_labels = []
            gt_bboxes_ignore = []
            gt_labels_ignore = []

            # Objects outside CLASSES (e.g. 'DontCare') go to the ignore lists.
            for bbox_name, bbox in zip(bbox_names, bboxes):
                if bbox_name in cat2label:
                    gt_labels.append(cat2label[bbox_name])
                    gt_bboxes.append(bbox)
                else:
                    gt_labels_ignore.append(-1)
                    gt_bboxes_ignore.append(bbox)

            # np.int64 replaces the np.long alias, which NumPy 1.24 removed.
            data_anno = dict(
                bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
                labels=np.array(gt_labels, dtype=np.int64),
                bboxes_ignore=np.array(
                    gt_bboxes_ignore, dtype=np.float32).reshape(-1, 4),
                labels_ignore=np.array(gt_labels_ignore, dtype=np.int64))

            data_info.update(ann=data_anno)
            data_infos.append(data_info)

        return data_infos
四、更改配置文件
在模型库下载模型文件 'mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth',并放到 checkpoints/ 目录下(下面的 cfg.load_from 会从该路径加载)。
from mmcv import Config
from mmdet.apis import set_random_seed

# Start from the stock Faster R-CNN config and override it for KITTI tiny.
cfg = Config.fromfile('./configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco.py')

# Modify dataset type and path
cfg.dataset_type = 'KittiTinyDataset'
cfg.data_root = 'kitti_tiny/'

# NOTE(review): the test split points at train.txt, as in the original
# script; confirm this is intended.
cfg.data.test.type = 'KittiTinyDataset'
cfg.data.test.data_root = 'kitti_tiny/'
cfg.data.test.ann_file = 'train.txt'
cfg.data.test.img_prefix = 'training/image_2'

cfg.data.train.type = 'KittiTinyDataset'
cfg.data.train.data_root = 'kitti_tiny/'
cfg.data.train.ann_file = 'train.txt'
cfg.data.train.img_prefix = 'training/image_2'

cfg.data.val.type = 'KittiTinyDataset'
cfg.data.val.data_root = 'kitti_tiny/'
cfg.data.val.ann_file = 'val.txt'
cfg.data.val.img_prefix = 'training/image_2'

# modify num classes of the model in box head
cfg.model.roi_head.bbox_head.num_classes = 3
# We can still use the pre-trained Mask RCNN model though we do not need to
# use the mask branch
cfg.load_from = 'checkpoints/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth'

# Set up working dir to save files and logs.
cfg.work_dir = './tutorial_exps'

# The original learning rate (LR) is set for 8-GPU training.
# We divide it by 8 since we only use one GPU.
cfg.optimizer.lr = 0.02 / 8
cfg.lr_config.warmup = None
cfg.log_config.interval = 10

# Change the evaluation metric since we use customized dataset.
cfg.evaluation.metric = 'mAP'
# We can set the evaluation interval to reduce the evaluation times
cfg.evaluation.interval = 12
# We can set the checkpoint saving interval to reduce the storage cost
cfg.checkpoint_config.interval = 12

# Set seed thus the results are more reproducible
cfg.seed = 0
set_random_seed(0, deterministic=False)

# Recent MMDetection 2.x releases read cfg.device inside train_detector;
# without it training stops with an AttributeError on the config.
# Older releases simply ignore the extra key.
cfg.device = 'cuda'
cfg.gpu_ids = range(1)
五、开始你的第一次mmdetection训练
from mmdet.apis import train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector

# Build the training dataset and keep it in a single-element list.
datasets = [build_dataset(cfg.data.train)]

# Build the detector, forwarding the optional train/test sub-configs.
detector_kwargs = dict(
    train_cfg=cfg.get('train_cfg'),
    test_cfg=cfg.get('test_cfg'),
)
model = build_detector(cfg.model, **detector_kwargs)

# Expose the class names on the model for visualization convenience.
model.CLASSES = datasets[0].CLASSES

# Make sure the output directory exists before training starts.
work_dir = osp.abspath(cfg.work_dir)
mmcv.mkdir_or_exist(work_dir)

# Single-process training with validation enabled.
train_detector(model, datasets, cfg, distributed=False, validate=True)
六、开始你的第一次mmdetection推理
from mmdet.apis import inference_detector, init_detector, show_result_pyplot

# Attach the config to the freshly trained model
# (presumably read by inference_detector -- confirm against the mmdet API).
model.cfg = cfg

# Run the detector on one sample image and display the detections.
img = mmcv.imread('kitti_tiny/training/image_2/000068.jpeg')
result = inference_detector(model, img)
show_result_pyplot(model, img, result)
上面这些 Python 代码需要放到同一个文件中执行。训练完成之后,可以单独使用下面的代码进行推理。
import os

import mmcv
from mmcv import Config
from mmdet.apis import inference_detector, init_detector, show_result_pyplot

cfg = Config.fromfile('./configs/faster_rcnn/faster_rcnn_r50_caffe_fpn_mstrain_1x_coco.py')
checkpoint_file = 'tutorial_exps/epoch_12.pth'
# Must match the number of classes the checkpoint was trained with.
cfg.model.roi_head.bbox_head.num_classes = 3

# Build the model from the config and the checkpoint file.
model = init_detector(cfg, checkpoint_file, device='cuda:0')

file_path = "/home/mby/mmdetection/data/kitti_tiny/training/image_2/"
for root, _dirs, files in os.walk(file_path, topdown=False):
    for image_id in files:
        # Join with the directory os.walk is currently visiting; prefixing
        # file_path would give a wrong path for files in sub-directories.
        filename = os.path.join(root, image_id)
        image = mmcv.imread(filename)
        result = inference_detector(model, image)
        show_result_pyplot(model, image, result)