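Most models in MMDetection are trained on the COCO dataset, which covers 80 object categories. To detect other kinds of objects, we need to train the model on a custom dataset. MMDetection supports training detection models on custom datasets.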
Training a new model usually takes three steps:
- Support the new dataset
- Modify the config file
- Train the model
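MMDetection offers three ways to support a new dataset:
- Reorganize the dataset into COCO format
- Reorganize the dataset into the middle format
- Implement support for the new dataset directly
Here we use option 2, reorganizing the dataset into the middle format.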
The kitti_tiny dataset can be downloaded here:
Link: https://pan.baidu.com/s/1xlcOmMwUHjoSYWIP1tCzFg  Extraction code: hhe8
File structure of the kitti_tiny dataset:
kitti_tiny
├── training
│   ├── image_2
│   │   ├── 000000.jpeg
│   │   ├── 000001.jpeg
│   │   ├── 000002.jpeg
│   │   ├── 000003.jpeg
│   │   ├── 000004.jpeg
│   │   ├── 000005.jpeg
│   │   ├── ......
│   │   └── 000074.jpeg
│   └── label_2
│       ├── 000000.txt
│       ├── 000001.txt
│       ├── 000002.txt
│       ├── 000003.txt
│       ├── 000004.txt
│       ├── 000005.txt
│       ├── ......
│       └── 000074.txt
├── train.txt  # train.txt lists 000000, 000001, ..., 000049
└── val.txt    # val.txt lists 000050, 000051, ..., 000074
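kitti_tiny already ships train.txt and val.txt. If you prepare a similar split for your own images, each file is simply one image id per line without the extension, which is the form the conversion script reads with mmcv.list_from_file. A minimal sketch under that assumption; make_split_ids and write_split are hypothetical helper names, and the 50/25 split just mirrors kitti_tiny:

```python
def make_split_ids(start, stop):
    """Return zero-padded 6-digit image ids for range(start, stop)."""
    return [f'{i:06d}' for i in range(start, stop)]


def write_split(path, ids):
    """Write one id per line, the layout expected for train.txt / val.txt."""
    with open(path, 'w') as f:
        f.write('\n'.join(ids) + '\n')


train_ids = make_split_ids(0, 50)   # 000000 ... 000049
val_ids = make_split_ids(50, 75)    # 000050 ... 000074
print(train_ids[-1], val_ids[0])    # 000049 000050
```

Calling write_split('train.txt', train_ids) and write_split('val.txt', val_ids) would produce files in the same shape as the ones in kitti_tiny.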
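The image below shows 000073.jpeg:
Next, look at 000073.txt, the annotation file for 000073.jpeg. Each line of the file annotates one object. The first line, Pedestrian, is a pedestrian, and "237.23 173.70 312.33 365.33" are its bounding-box coordinates (left, top, right, bottom, in pixels). The other lines follow the same pattern. DontCare marks regions whose objects are far away, crowded, or very small; they would be very hard to recognize, so the objects inside these boxes are not considered.
According to the KITTI documentation, the first column gives the object class and the 5th to 8th columns give the bounding box. We need to read the annotations of each image and convert them into the middle format that MMDetection accepts. The contents of 000073.txt are: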
Pedestrian 0.00 0 -2.62 237.23 173.70 312.33 365.33 1.58 0.66 0.53 -2.99 1.60 6.32 -3.05
Pedestrian 0.00 1 0.80 189.46 158.23 256.19 356.44 1.70 0.61 0.51 -3.62 1.58 6.54 0.31
Pedestrian 0.00 0 0.45 752.95 164.08 791.19 288.78 1.75 0.63 0.51 2.28 1.63 10.51 0.65
Cyclist 0.00 0 1.78 444.66 170.48 485.70 241.86 1.64 0.57 2.00 -3.55 1.60 17.61 1.58
Cyclist 0.00 0 1.65 494.34 168.08 517.01 223.73 1.80 0.60 1.85 -3.54 1.66 24.31 1.51
Pedestrian 0.00 0 -2.07 546.73 177.07 560.52 214.88 1.53 0.61 0.73 -2.41 1.71 29.83 -2.15
Pedestrian 0.00 0 -2.02 535.68 174.41 549.63 214.38 1.61 0.54 0.87 -2.86 1.68 29.55 -2.12
DontCare -1 -1 -10 596.02 166.69 615.85 203.19 -1 -1 -1 -1000 -1000 -1000 -10
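To make the column layout concrete, here is a small sketch that reads label lines the same way the conversion script does (class name from column 1, box from columns 5 to 8) and separates the three kept classes from everything else, such as DontCare. The function names are hypothetical, and unlike the script it splits on arbitrary whitespace rather than a single space:

```python
CLASSES = ('Car', 'Pedestrian', 'Cyclist')


def parse_kitti_line(line):
    """Split one KITTI label line into (class_name, [left, top, right, bottom])."""
    fields = line.split()
    return fields[0], [float(v) for v in fields[4:8]]


def split_objects(lines, classes=CLASSES):
    """Return (kept, ignored): kept holds (label_index, bbox), ignored holds bboxes."""
    kept, ignored = [], []
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        name, bbox = parse_kitti_line(line)
        if name in classes:
            kept.append((classes.index(name), bbox))
        else:
            ignored.append(bbox)  # DontCare and any other class end up here
    return kept, ignored


sample = [
    'Pedestrian 0.00 0 -2.62 237.23 173.70 312.33 365.33 1.58 0.66 0.53 -2.99 1.60 6.32 -3.05',
    'DontCare -1 -1 -10 596.02 166.69 615.85 203.19 -1 -1 -1 -1000 -1000 -1000 -10',
]
kept, ignored = split_objects(sample)
print(kept)     # [(1, [237.23, 173.7, 312.33, 365.33])]
print(ignored)  # [[596.02, 166.69, 615.85, 203.19]]
```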
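Next, let's look at MMDetection's middle format:
# First, it is one big list in which each item is one image. Each image corresponds to a dictionary containing the image's filename, width, height and annotations ann.
# ann holds the annotations of all objects in the image. If the image contains n objects, we provide an n×4 array bboxes with the coordinates of all bounding boxes, plus a length-n vector labels giving the class of each object.
# bboxes_ignore and labels_ignore hold the DontCare regions mentioned earlier; the DontCare boxes go into these fields.
# Once we understand both the MMDetection middle format and the KITTI format, we can convert the KITTI dataset into the middle format.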
[
    {
        'filename': 'a.jpg',
        'width': 1280,
        'height': 720,
        'ann': {
            'bboxes': <np.ndarray> (n, 4),
            'labels': <np.ndarray> (n, ),
            'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
            'labels_ignore': <np.ndarray> (k, ) (optional field)
        }
    },
    ...
]
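As a concrete illustration of the structure above, this sketch builds one entry with NumPy. make_data_info is a hypothetical helper, not an MMDetection function, and the filename and size come from the 'a.jpg' example. Note the reshape(-1, 4): it keeps an image with no boxes at shape (0, 4) instead of (0,), which is why the conversion script uses it too.

```python
import numpy as np


def make_data_info(filename, width, height, bboxes, labels,
                   bboxes_ignore=(), labels_ignore=()):
    """Assemble one image entry in the middle format (hypothetical helper)."""
    return dict(
        filename=filename,
        width=width,
        height=height,
        ann=dict(
            bboxes=np.array(bboxes, dtype=np.float32).reshape(-1, 4),
            labels=np.array(labels, dtype=np.int64),
            bboxes_ignore=np.array(bboxes_ignore, dtype=np.float32).reshape(-1, 4),
            labels_ignore=np.array(labels_ignore, dtype=np.int64),
        ),
    )


info = make_data_info('a.jpg', 1280, 720,
                      bboxes=[[237.23, 173.70, 312.33, 365.33]], labels=[1],
                      bboxes_ignore=[[596.02, 166.69, 615.85, 203.19]],
                      labels_ignore=[-1])
empty = make_data_info('a.jpg', 1280, 720, bboxes=[], labels=[])
print(info['ann']['bboxes'].shape, empty['ann']['bboxes'].shape)  # (1, 4) (0, 4)
```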
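The code files below are placed in the demo directory; demo and checkpoints are sibling directories.
Here is the code that converts the KITTI dataset into MMDetection's middle format: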
# encoding:utf-8
import os.path as osp

import mmcv
import numpy as np


def convert_kitti_to_middle(ann_file, out_file, img_prefix):
    CLASSES = ('Car', 'Pedestrian', 'Cyclist')
    # lookup table from class name to label index
    cat2label = {k: i for i, k in enumerate(CLASSES)}
    # load the list of image ids (one per line) from train.txt / val.txt
    image_list = mmcv.list_from_file(ann_file)
    data_infos = []
    # convert annotations to the middle format
    for image_id in image_list:
        filename = f'{img_prefix}/{image_id}.jpeg'
        image = mmcv.imread(filename)
        height, width = image.shape[:2]
        # each image is stored in one dictionary; filename is relative to img_prefix
        data_info = dict(filename=f'{image_id}.jpeg', width=width, height=height)
        # load annotations from the matching label_2/<id>.txt file
        label_prefix = img_prefix.replace('image_2', 'label_2')
        lines = mmcv.list_from_file(osp.join(label_prefix, f'{image_id}.txt'))
        content = [line.strip().split(' ') for line in lines]
        bbox_names = [x[0] for x in content]
        # columns 5-8 (indices 4:8) hold the box: left, top, right, bottom
        bboxes = [[float(info) for info in x[4:8]] for x in content]
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        gt_labels_ignore = []
        # keep the three classes; everything else (e.g. 'DontCare') goes to the ignore lists
        for bbox_name, bbox in zip(bbox_names, bboxes):
            if bbox_name in cat2label:
                gt_labels.append(cat2label[bbox_name])
                gt_bboxes.append(bbox)
            else:
                gt_labels_ignore.append(-1)
                gt_bboxes_ignore.append(bbox)
        # convert boxes and labels to numpy arrays
        # (np.long no longer exists in recent NumPy, so np.int64 is used)
        data_anno = dict(
            bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
            labels=np.array(gt_labels, dtype=np.int64),
            bboxes_ignore=np.array(gt_bboxes_ignore,
                                   dtype=np.float32).reshape(-1, 4),
            labels_ignore=np.array(gt_labels_ignore, dtype=np.int64))
        data_info.update(ann=data_anno)
        # all images and their annotations are collected in one list
        data_infos.append(data_info)
    mmcv.dump(data_infos, out_file)


if __name__ == '__main__':
    convert_kitti_to_middle(ann_file='../kitti_tiny/train.txt',
                            out_file='../kitti_tiny/train_middle.pkl',
                            img_prefix='../kitti_tiny/training/image_2')
    convert_kitti_to_middle(ann_file='../kitti_tiny/val.txt',
                            out_file='../kitti_tiny/val_middle.pkl',
                            img_prefix='../kitti_tiny/training/image_2')
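Before training, it can save time to sanity-check the generated pickle. The sketch below asserts the properties the middle format requires (box array of shape (n, 4), label vector of length n, labels within the class range) plus that each box is ordered (left, top, right, bottom). check_data_infos is a hypothetical helper, and the sample entry is made up for illustration:

```python
import numpy as np


def check_data_infos(data_infos, num_classes):
    """Assert that every entry follows the middle format; return the entry count."""
    for info in data_infos:
        for key in ('filename', 'width', 'height', 'ann'):
            assert key in info, f'missing key: {key}'
        bboxes, labels = info['ann']['bboxes'], info['ann']['labels']
        # bboxes must be (n, 4) and labels (n,)
        assert bboxes.ndim == 2 and bboxes.shape[1] == 4, bboxes.shape
        assert labels.shape == (bboxes.shape[0],), labels.shape
        # labels index into CLASSES, so they must lie in [0, num_classes)
        if labels.size:
            assert labels.min() >= 0 and labels.max() < num_classes
        # right >= left and bottom >= top for every box
        assert np.all(bboxes[:, 2] >= bboxes[:, 0])
        assert np.all(bboxes[:, 3] >= bboxes[:, 1])
    return len(data_infos)


sample = [dict(filename='a.jpg', width=1280, height=720,
               ann=dict(bboxes=np.array([[237.23, 173.70, 312.33, 365.33]], dtype=np.float32),
                        labels=np.array([1], dtype=np.int64)))]
print(check_data_infos(sample, num_classes=3))  # 1
```

With the real output, load the list with mmcv.load('../kitti_tiny/train_middle.pkl') and pass it to check_data_infos(data_infos, num_classes=3).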
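Next, modify the parameters in the config file.
We use the Faster R-CNN model. The checkpoint can be downloaded from: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_caffe_fpn_1x_coco/faster_rcnn_r50_caffe_fpn_1x_coco_bbox_mAP-0.378_20200504_180032-c5925ee5.pth
Note that this link is for the caffe-style variant (faster_rcnn_r50_caffe_fpn_1x_coco), while the config file and cfg.load_from in the code below use faster_rcnn_r50_fpn_1x_coco (faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth). Make sure load_from points to the checkpoint file you actually downloaded and that it matches the config.
The checkpoints directory looks like this:
Next, write code that loads the original config file and modifies the relevant parameters on top of it: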
from mmcv import Config
from mmdet.apis import set_random_seed
cfg = Config.fromfile('../configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')
# Modify dataset type and path
# CustomDataset is the dataset class that reads MMDetection's middle format.
# dataset_type and data_root were already substituted into cfg.data when the file
# was loaded, so the per-split settings below are what actually take effect.
cfg.dataset_type = 'CustomDataset'
cfg.data_root = '../kitti_tiny/'  # demo and kitti_tiny are sibling directories, hence '../'
# Class names of the dataset. Setting cfg.classes alone does not take effect;
# classes must also be set on cfg.data.train, cfg.data.val and cfg.data.test.
cfg.classes = ('Car', 'Pedestrian', 'Cyclist')
cfg.data.test.type = 'CustomDataset'
cfg.data.test.data_root = '../kitti_tiny/'
# The middle-format file saved earlier. The test split deliberately reuses the
# training file so we can see how the model performs on the training set.
cfg.data.test.ann_file = 'train_middle.pkl'
cfg.data.test.img_prefix = 'training/image_2'
cfg.data.test.classes = ('Car', 'Pedestrian', 'Cyclist')
cfg.data.train.type = 'CustomDataset'
cfg.data.train.data_root = '../kitti_tiny/'
cfg.data.train.ann_file = 'train_middle.pkl'
cfg.data.train.img_prefix = 'training/image_2'
cfg.data.train.classes = ('Car', 'Pedestrian', 'Cyclist')
cfg.data.val.type = 'CustomDataset'
cfg.data.val.data_root = '../kitti_tiny/'
cfg.data.val.ann_file = 'val_middle.pkl'
cfg.data.val.img_prefix = 'training/image_2'
cfg.data.val.classes = ('Car', 'Pedestrian', 'Cyclist')
# Modify the number of classes of the box head
cfg.model.roi_head.bbox_head.num_classes = 3  # ('Car', 'Pedestrian', 'Cyclist')
# Fine-tune from the pretrained Faster R-CNN weights
cfg.load_from = '../checkpoints/faster_rcnn/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
cfg.work_dir = './'  # working directory for saved files and logs
# The original learning rate (LR) is set for 8-GPU training.
# We divide it by 8 since we only use one GPU.
cfg.optimizer.lr = 0.02 / 8
cfg.lr_config.warmup = None
cfg.log_config.interval = 10
# Change the evaluation metric since we use a customized dataset.
cfg.evaluation.metric = 'mAP'
# Evaluate less often to save time
cfg.evaluation.interval = 12
# Save checkpoints less often to reduce storage cost
cfg.checkpoint_config.interval = 12
# Set the seed so that results are more reproducible
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
# Print the final config used for training
print(f'Config:\n{cfg.pretty_text}')
# Save the final config to a file (don't skip this: tools/test.py needs it later)
cfg.dump(f'{cfg.work_dir}/customformat_kitti.py')
import joblib
joblib.dump(cfg, './cfg.dump')  # also pickle the config for the inference script
####################################
# Train the new model
# Build the dataset and the detector from the config, then train
import os.path as osp
import mmcv
from mmdet.apis import train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
datasets = [build_dataset(cfg.data.train)]  # build the training dataset
model = build_detector(cfg.model)  # build the detector
model.CLASSES = datasets[0].CLASSES  # attach class names so visualizations show text labels
# Create the work directory and train the model
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_detector(model, datasets, cfg, distributed=False, validate=True)
joblib.dump(model, './model.dump')  # pickle the trained model for the inference script
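The config script above repeats the same five assignments (type, data_root, ann_file, img_prefix, classes) for test, train and val. A minimal sketch that does this in a loop; apply_custom_dataset and KITTI_SPLITS are hypothetical names, and the test split reuses the training file exactly as in the script. SimpleNamespace stands in for cfg.data so the sketch runs without mmcv:

```python
from types import SimpleNamespace

# Annotation file per split; test reuses the training file, as in the config above
KITTI_SPLITS = {'train': 'train_middle.pkl',
                'val': 'val_middle.pkl',
                'test': 'train_middle.pkl'}


def apply_custom_dataset(data_cfg, classes, data_root, img_prefix, splits=KITTI_SPLITS):
    """Set type, data_root, ann_file, img_prefix and classes on each split."""
    for split, ann_file in splits.items():
        split_cfg = getattr(data_cfg, split)
        split_cfg.type = 'CustomDataset'
        split_cfg.data_root = data_root
        split_cfg.ann_file = ann_file
        split_cfg.img_prefix = img_prefix
        split_cfg.classes = classes


# Stand-in for cfg.data
data = SimpleNamespace(train=SimpleNamespace(), val=SimpleNamespace(), test=SimpleNamespace())
apply_custom_dataset(data, ('Car', 'Pedestrian', 'Cyclist'), '../kitti_tiny/', 'training/image_2')
print(data.val.ann_file)  # val_middle.pkl
```

With the real config this would be called as apply_custom_dataset(cfg.data, cfg.classes, cfg.data_root, 'training/image_2'), relying on the same attribute assignment the original script already uses.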
# Evaluation results after some training:
+------------+-----+------+--------+-------+
| class | gts | dets | recall | ap |
+------------+-----+------+--------+-------+
| Car | 62 | 151 | 0.919 | 0.822 |
| Pedestrian | 13 | 55 | 0.923 | 0.771 |
| Cyclist | 7 | 62 | 0.571 | 0.081 |
+------------+-----+------+--------+-------+
| mAP | | | | 0.558 |
+------------+-----+------+--------+-------+
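The mAP row is the unweighted mean of the per-class AP values, so one weak class pulls it down noticeably; here Cyclist, with only 7 ground-truth boxes, has an AP of 0.081. A quick arithmetic check using the numbers from the table:

```python
# Per-class AP values copied from the evaluation table above
ap = {'Car': 0.822, 'Pedestrian': 0.771, 'Cyclist': 0.081}
mean_ap = sum(ap.values()) / len(ap)
print(round(mean_ap, 3))  # 0.558
```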
Model evaluation
Switch to the demo directory and run:
python ../tools/test.py customformat_kitti.py latest.pth --eval mAP
The evaluation results are shown below:
Testing the trained model:
# encoding:utf-8
import joblib
import mmcv
from mmdet.apis import inference_detector, show_result_pyplot
cfg = joblib.load('./cfg.dump')
model = joblib.load('./model.dump')
model.cfg = cfg  # inference_detector builds the test pipeline from model.cfg
model.eval()  # switch to inference mode after training
# images 000060.jpeg ... 000069.jpeg (part of the val split)
for i in range(60, 70):
    img = mmcv.imread(f'../kitti_tiny/training/image_2/{i:06d}.jpeg')
    result = inference_detector(model, img)
    show_result_pyplot(model, img, result)
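Besides looking at the plots, it can help to count detections per class. In MMDetection 2.x, which this tutorial uses, a box-only detector such as Faster R-CNN returns, as far as I know, a list with one (k, 5) array per class, each row being [x1, y1, x2, y2, score]. The sketch below assumes that format; summarize_result is a hypothetical helper, the 0.3 score threshold is an assumption, and the result is made up for illustration:

```python
import numpy as np


def summarize_result(result, class_names, score_thr=0.3):
    """Count detections per class whose score is >= score_thr.

    Assumes result is a list with one (k, 5) array per class,
    each row being [x1, y1, x2, y2, score].
    """
    counts = {}
    for name, dets in zip(class_names, result):
        dets = np.asarray(dets, dtype=np.float32).reshape(-1, 5)
        counts[name] = int((dets[:, 4] >= score_thr).sum())
    return counts


# Made-up result: one strong and one weak Car box, no Pedestrian, one Cyclist box
fake_result = [
    np.array([[10, 10, 50, 40, 0.95], [60, 12, 90, 40, 0.20]], dtype=np.float32),
    np.zeros((0, 5), dtype=np.float32),
    np.array([[5, 5, 20, 30, 0.35]], dtype=np.float32),
]
print(summarize_result(fake_result, ('Car', 'Pedestrian', 'Cyclist')))
# {'Car': 1, 'Pedestrian': 0, 'Cyclist': 1}
```

Inside the loop above, print(summarize_result(result, model.CLASSES)) would print the counts for each test image.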
The figure below shows one of the detection results: