MMDetection学习笔记-06使用MMDetection训练监测模型

MMDetection中模型大多给予coco数据集进行训练。coco数据集包含80种物体。如果我们希望模型检测到其它新类型的物体,就需要使用自定义数据集来训练模型。MMDetection支持使用自定义模型训练监测模型。
训练新模型通常有三个步骤:

  1. 支持新数据集
  2. 修改配置文件
  3. 训练模型

MMDetection有三种来支持新数据集:

  1. 将数据集整理为coco格式
  2. 将数据集整理为中间格式
  3. 直接实现新数据集的支持

这里将使用【2. 将数据集整理为中间格式】来表示数据集。
kitti_tiny的数据集见链接:
链接: https://pan.baidu.com/s/1xlcOmMwUHjoSYWIP1tCzFg 提取码: hhe8
kitti_tiny数据集文件结构:

kitti_tiny
├── training
│   ├── image_2
│   │   ├── 000000.jpeg
│   │   ├── 000001.jpeg
│   │   ├── 000002.jpeg
│   │   ├── 000003.jpeg
│   │   ├── 000004.jpeg
│   │   ├── 000005.jpeg
│	│	│—— ......
│   │   ├── 000074.jpeg
│   └── label_2
│       ├── 000000.txt
│       ├── 000001.txt
│       ├── 000002.txt
│       ├── 000003.txt
│       ├── 000004.txt
│       ├── 000005.txt
│		│—— ......
│       ├── 000074.txt
├── train.txt 	#train.txt包含000000,000001,......,000049
└── val.txt		#val.txt包含000050,000051,......,000074

kitti_tiny数据集所放的位置
下面图片展示了000073.jpg:
000073.jpg
接下来看000073.jpg对应的标注000073.txt。文档中一行代表一个物体的标注。第一行pedestrian代表一个行人。“237.23 173.70 312.33 365.33”代表坐标。其它标注类似。DontCare表示很远地方的物体,他们可能很拥挤,或者特别小。识别起来可能会特别困难,所以就不考虑这个框框中的物体了。
According to the KITTI’s documentation, the first column indicates the class of the object, and the 5th to 8th columns indicates the bboxes. We need to read annotations of each image and convert them into middle format MMDetection accept is as below:

Pedestrian 0.00 0 -2.62 237.23 173.70 312.33 365.33 1.58 0.66 0.53 -2.99 1.60 6.32 -3.05
Pedestrian 0.00 1 0.80 189.46 158.23 256.19 356.44 1.70 0.61 0.51 -3.62 1.58 6.54 0.31
Pedestrian 0.00 0 0.45 752.95 164.08 791.19 288.78 1.75 0.63 0.51 2.28 1.63 10.51 0.65
Cyclist 0.00 0 1.78 444.66 170.48 485.70 241.86 1.64 0.57 2.00 -3.55 1.60 17.61 1.58
Cyclist 0.00 0 1.65 494.34 168.08 517.01 223.73 1.80 0.60 1.85 -3.54 1.66 24.31 1.51
Pedestrian 0.00 0 -2.07 546.73 177.07 560.52 214.88 1.53 0.61 0.73 -2.41 1.71 29.83 -2.15
Pedestrian 0.00 0 -2.02 535.68 174.41 549.63 214.38 1.61 0.54 0.87 -2.86 1.68 29.55 -2.12
DontCare -1 -1 -10 596.02 166.69 615.85 203.19 -1 -1 -1 -1000 -1000 -1000 -10

接下来要看MMDetection中间数据集的格式:

#首先,它是一个大的列表。列表中每一个项目都是一个图片。每个图片对应一个字典。这个字典包含了图片的文件名filename、宽度width、高度height、标注ann(annotation)。
#ann中包含了所有类别的标注。假设图片中有n个物体,那么我们需要提供一个n*4的数组bboxes。这个数组包含所有边界框的坐标。并提供一个长度为n的向量labels,用来标注每一个物体的类别。
# bboxes_ignore和labels_ignore就是之前提到的DontCare。需要将DontCare填写到其中。
#在了解MMDetection中间数据集的格式和KITTI的数据集的格式,就可以将KITTI数据集转换为中间数据集的格式了。
[
    {
        'filename': 'a.jpg',
        'width': 1280,
        'height': 720,
        'ann': {
            'bboxes': <np.ndarray> (n, 4),
            'labels': <np.ndarray> (n, ),
            'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
            'labels_ignore': <np.ndarray> (k, 4) (optional field)
        }
    },
    ...
]

接下来的代码文件位于demo目录下。demo与checkpoints是同级
接下来是KITTI数据集转换为MMDetection中间数据集的代码:

# encoding:utf-8
import os.path as osp

import mmcv
import numpy as np


def convert_titti_to_middle(ann_file, out_file, img_prefix):
    CLASSES = ('Car', 'Pedestrian', 'Cyclist')
    # 类别反差表
    cat2label = {k: i for i, k in enumerate(CLASSES)}
    # load image list from file
    image_list = mmcv.list_from_file(ann_file)
    data_infos = []
    # convert annotations to middle format
    for image_id in image_list:
        filename = f'{img_prefix}/{image_id}.jpeg'
        image = mmcv.imread(filename)
        height, width = image.shape[:2]
        # A picture is stored in a dictionary
        data_info = dict(filename=f'{image_id}.jpeg', width=width, height=height)

        # load annotations
        label_prefix = img_prefix.replace('image_2', 'label_2')
        lines = mmcv.list_from_file(osp.join(label_prefix, f'{image_id}.txt'))

        content = [line.strip().split(' ') for line in lines]
        bbox_names = [x[0] for x in content]
        bboxes = [[float(info) for info in x[4:8]] for x in content]

        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_labels_ignore = []

        # filter 'DontCare'
        for bbox_name, bbox in zip(bbox_names, bboxes):
            if bbox_name in cat2label:
                gt_labels.append(cat2label[bbox_name])
                gt_bboxes.append(bbox)
            else:
                gt_labels_ignore.append(-1)
                gt_bboxes_ignore.append(bbox)
        # 将标注信息(坐标和标签)转换为nparray
        data_anno = dict(
            bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
            labels=np.array(gt_labels, dtype=np.long),
            bboxes_ignore=np.array(gt_bboxes_ignore,
                                   dtype=np.float32).reshape(-1, 4),
            labels_ignore=np.array(gt_labels_ignore, dtype=np.long))

        data_info.update(ann=data_anno)
        # 所有图片和标注信息存储在一个列表中
        data_infos.append(data_info)
    mmcv.dump(data_infos, out_file)
    print()

if __name__ == '__main__':

    convert_titti_to_middle(ann_file="../kitti_tiny/train.txt", out_file="../kitti_tiny/train_middle.pkl",
                            img_prefix="../kitti_tiny/training/image_2")

    convert_titti_to_middle(ann_file="../kitti_tiny/val.txt", out_file="../kitti_tiny/val_middle.pkl",
                            img_prefix="../kitti_tiny/training/image_2")

接下来是修改配置文件的参数

选用faster rcnn模型。对应的checkpoints下载地址是:https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_caffe_fpn_1x_coco/faster_rcnn_r50_caffe_fpn_1x_coco_bbox_mAP-0.378_20200504_180032-c5925ee5.pth

checkpoints文件目录如下所示:
在这里插入图片描述
接下来是编写代码加载原先的配置文件,并在原先的配置文件上修改相应的参数:

from mmcv import Config
from mmdet.apis import set_random_seed

cfg = Config.fromfile('../configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')

# Modify dataset type and path
cfg.dataset_type = 'CustomDataset' #首先要把数据集改成CustomDataset,这个代表MMDetection的中间数据的格式
cfg.data_root = '../kitti_tiny/'# 修改数据存储的路径。因为demo文件夹与kitti_tiny文件夹是同级,所以需要使用两个“.”
cfg.classes = ('Car', 'Pedestrian', 'Cyclist')#指明数据集中每个物体的类别名称。在cfg下修改是不会直接生效的,还是需要在cfg.data.[train|test|val].classes = ('Car', 'Pedestrian', 'Cyclist')下进行修改,才能生效。train,test和val都需要修改。

cfg.data.test.type = 'CustomDataset'
cfg.data.test.data_root = '../kitti_tiny/'
cfg.data.test.ann_file = 'train_middle.pkl'#还需要指明刚刚保存的中间数据集的路径和名称。这里测试集也用了训练集的中间数据集,主要是为了看在训练集上的表现。
cfg.data.test.img_prefix = 'training/image_2'
cfg.data.test.classes = ('Car', 'Pedestrian', 'Cyclist')

cfg.data.train.type = 'CustomDataset'
cfg.data.train.data_root = '../kitti_tiny/'
cfg.data.train.ann_file = 'train_middle.pkl'
cfg.data.train.img_prefix = 'training/image_2'
cfg.data.train.classes = ('Car', 'Pedestrian', 'Cyclist')

cfg.data.val.type = 'CustomDataset'
cfg.data.val.data_root = '../kitti_tiny/'
cfg.data.val.ann_file = 'val_middle.pkl'
cfg.data.val.img_prefix = 'training/image_2'
cfg.data.val.classes = ('Car', 'Pedestrian', 'Cyclist')

# modify num classes of the model in box head
cfg.model.roi_head.bbox_head.num_classes = 3 #classes = ('Car', 'Pedestrian', 'Cyclist')

cfg.load_from = "../checkpoints/faster_rcnn/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth" #使用预训练好的faster rcnn模型用于fine tuning

cfg.work_dir = './' # Set up working dir to save files and logs.

# The original learning rate (LR) is set for 8-GPU training.
# We divide it by 8 since we only use one GPU.
cfg.optimizer.lr = 0.02 / 8
cfg.lr_config.warmup = None
cfg.log_config.interval = 10

# Change the evaluation metric since we use customized dataset.
cfg.evaluation.metric = 'mAP'
# We can set the evaluation interval to reduce the evaluation times
cfg.evaluation.interval = 12
# We can set the checkpoint saving interval to reduce the storage cost
cfg.checkpoint_config.interval = 12

# Set seed thus the results are more reproducible
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)

# We can initialize the logger for training and have a look
# at the final config used for training
print(f'Config:\n{cfg.pretty_text}')
# 保存模型的各种参数(一定要记得嗷)
cfg.dump(F'{cfg.work_dir}/customformat_kitti.py')

import joblib
joblib.dump(cfg, "./cfg.dump")

####################################
# 训练新模型
# 根据配置文件构建数据集,监测模型,并完成训练
import mmcv
from mmdet.apis import train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
import os.path as osp

datasets = [build_dataset(cfg.data.train)] #构件数据集
model = build_detector(cfg.model) #构建监测模型
model.CLASSES = datasets[0].CLASSES #添加类别文字属性来提高可视化效果

#创建工作目录并训练模型
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 
train_detector(model, datasets, cfg, distributed=False, validate=True)
joblib.dump(model, "./model.dump")
print()
# 经过一段时间训练后的评估
+------------+-----+------+--------+-------+
| class      | gts | dets | recall | ap    |
+------------+-----+------+--------+-------+
| Car        | 62  | 151  | 0.919  | 0.822 |
| Pedestrian | 13  | 55   | 0.923  | 0.771 |
| Cyclist    | 7   | 62   | 0.571  | 0.081 |
+------------+-----+------+--------+-------+
| mAP        |     |      |        | 0.558 |
+------------+-----+------+--------+-------+

模型评估
切换到demo目录,执行下列代码:

 python ../tools/test.py customformat_kitti.py latest.pth  --eval mAP

在这里插入图片描述
模型评估结果如下:
在这里插入图片描述
测试训练好的模型:

# encoding:utf-8

import joblib
import mmcv

from mmdet.apis import inference_detector, show_result_pyplot

cfg = joblib.load("./cfg.dump")
model = joblib.load("./model.dump")
model.cfg = cfg

for i in range(60, 70):
    img = mmcv.imread('../kitti_tiny/training/image_2/0000' + str(i) + '.jpeg')
    result = inference_detector(model, img)
    show_result_pyplot(model, img, result)

下图展示了其中一个训练结果:
在这里插入图片描述

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

www5599667788

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值