【mmdetection】使用kitti数据集进行训练

目录

一、环境配置

二、Kitti数据集准备

三、仓库中需要修改的文件

3.1 mmdet/datasets中添加kitti.py,内容如下

3.2 修改mmdet/datasets/__init__.py,修改位置已注释标出

3.3 configs/_base_/datasets中添加kitti_detection.py,内容如下

3.4 修改mmdet\core\evaluation文件夹中的class_names.py文件

3.5 修改mmdet\core\evaluation/__init__.py文件,修改位置已注释标出

3.6 修改configs文件夹中需要使用的配置文件

四、训练

五、可视化

5.1 修改mmdet\apis文件夹中的inference.py文件

5.2 使用visualization.py(见下)可视化。(注:需要将visualization.py放在到mmdetection目录下)

5.3 或者使用DetVisGUI可视化

六、参考

七、附录

7.1 kitti标签类别合并

7.2 kitti转voc(可能有问题,不建议使用)


写在前面:官方给了一个demo程序将Kitti转为COCO格式,但是加载数据、修改配置、训练、测试、可视化这些东西都都放在一起总觉得不舒服,用那个比较好,于是把官方的示例改成了一个新的数据集kitti。

一、环境配置

mmdet 2.7.0

mmcv 1.2.1

克隆仓库中的源码,并在目录下创建data文件夹。按照get_started.md文件进行配置,不再赘述。

二、Kitti数据集准备

按照以下文件夹结构准备数据。

mmdetection
├── mmdet
├── tools
├── configs
├── data
│   ├── kitti
│   │   ├── training
│   │   │   ├── image_2
│   │   │   ├── label_2
│   │   ├── train.txt
│   │   ├── val.txt
│   │   ├── trainval.txt

三、仓库中需要修改的文件

3.1 mmdet/datasets中添加kitti.py,内容如下

import os.path as osp

import mmcv
import numpy as np

from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset


@DATASETS.register_module()
class KittiDataset(CustomDataset):
    CLASSES = ('Car', 'Pedestrian', 'Cyclist')

    def load_annotations(self, ann_file):
        cat2label = {k: i for i, k in enumerate(self.CLASSES)}
        # load image list from file
        image_list = mmcv.list_from_file(self.ann_file)

        data_infos = []
        # convert annotations to middle format
        for image_id in image_list:
            filename = f'{self.img_prefix}/{image_id}.png'
            image = mmcv.imread(filename)
            height, width = image.shape[:2]

            data_info = dict(filename=f'{image_id}.png', width=width, height=height)

            # load annotations
            label_prefix = self.img_prefix.replace('image_2', 'label_2')
            lines = mmcv.list_from_file(osp.join(label_prefix, f'{image_id}.txt'))

            content = [line.strip().split(' ') for line in lines]
            bbox_names = [x[0] for x in content]
            bboxes = [[float(info) for info in x[4:8]] for x in content]

            gt_bboxes = []
            gt_labels = []
            gt_bboxes_ignore = []
            gt_labels_ignore = []

            # filter 'DontCare'
            for bbox_name, bbox in zip(bbox_names, bboxes):
                if bbox_name in cat2label:
                    gt_labels.append(cat2label[bbox_name])
                    gt_bboxes.append(bbox)
                else:
                    gt_labels_ignore.append(-1)
                    gt_bboxes_ignore.append(bbox)

            data_anno = dict(
                bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
                labels=np.array(gt_labels, dtype=np.long),
                bboxes_ignore=np.array(gt_bboxes_ignore,
                                       dtype=np.float32).reshape(-1, 4),
                labels_ignore=np.array(gt_labels_ignore, dtype=np.long))

            data_info.update(ann=data_anno)
            data_infos.append(data_info)

        return data_infos

3.2 修改mmdet/datasets/__init__.py,修改位置已注释标出

from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .cityscapes import CityscapesDataset
from .coco import CocoDataset
from .custom import CustomDataset
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
                               RepeatDataset)
from .deepfashion import DeepFashionDataset
from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset
from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
from .utils import replace_ImageToTensor
from .voc import VOCDataset
from .wider_face import WIDERFaceDataset
from .xml_style import XMLDataset
from .kitti import KittiDataset            #新加

__all__ = [
    #下面的KittiDataset为新加
    'KittiDataset','CustomDataset', 'XMLDataset', 'CocoDataset', 'DeepFashionDataset',
    'VOCDataset', 'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset',
    'LVISV1Dataset', 'GroupSampler', 'DistributedGroupSampler',
    'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
    'ClassBalancedDataset', 'WIDERFaceDataset', 'DATASETS', 'PIPELINES',
    'build_dataset', 'replace_ImageToTensor'
]

3.3 configs/_base_/datasets中添加kitti_detection.py,内容如下

# dataset settings
dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=8,
    workers_per_gpu=8,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'train.txt',
        img_prefix=data_root + 'training/image_2',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'val.txt',
        img_prefix=data_root + 'training/image_2',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'val.txt',
        img_prefix=data_root + 'training/image_2',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='mAP')

3.4 修改mmdet\core\evaluation文件夹中的class_names.py文件

###添加一个函数
def kitti_classes():
    return ['Car', 'Pedestrian', 'Cyclist']

###修改dataset_aliases
dataset_aliases = {
    'kitti':['kitti'],        #添加kitti数据集
    'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'],
    'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'],
    'imagenet_vid': ['vid', 'imagenet_vid', 'ilsvrc_vid'],
    'coco': ['coco', 'mscoco', 'ms_coco'],
    'wider_face': ['WIDERFaceDataset', 'wider_face', 'WDIERFace'],
    'cityscapes': ['cityscapes']
}

3.5 修改mmdet\core\evaluation/__init__.py文件,修改位置已注释标出

from .class_names import (cityscapes_classes, coco_classes, dataset_aliases,
                          get_classes, imagenet_det_classes,
                          imagenet_vid_classes, voc_classes,
                          kitti_classes)    #新加
from .eval_hooks import DistEvalHook, EvalHook
from .mean_ap import average_precision, eval_map, print_map_summary
from .recall import (eval_recalls, plot_iou_recall, plot_num_recall,
                     print_recall_summary)

__all__ = [
    'kitti_classes',        #新加
    'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes',
    'coco_classes', 'cityscapes_classes', 'dataset_aliases', 'get_classes',
    'DistEvalHook', 'EvalHook', 'average_precision', 'eval_map',
    'print_map_summary', 'eval_recalls', 'print_recall_summary',
    'plot_num_recall', 'plot_iou_recall'
]

3.6 修改configs文件夹中需要使用的配置文件

这里使用的是configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py,只需要将_base_中的coco_detection.py改成kitti_detection.py,然后将类别数修改为3类即可。

四、训练

训练参考传送门,初步尝试使用的命令如下

CUDA_VISIBLE_DEVICES=6 nohup python tools/train.py configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py --gpus=1 --work-dir=fcos_output/TEST>fcos_r50_caffe_fpn_4x4_1x_coco.log 2>&1 &

五、可视化

5.1 修改mmdet\apis文件夹中的inference.py文件

###将第48行
model.CLASSES = get_classes('coco')
###修改为
model.CLASSES = get_classes('kitti')

5.2 使用visualization.py(见下)可视化。(注:需要将visualization.py放在到mmdetection目录下)

from mmdet.apis import init_detector, inference_detector, show_result_pyplot
import mmcv

config_file = 'configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py'
# download the checkpoint from model zoo and put it in `checkpoints/`
checkpoint_file = 'fcos_output/fcos_r50_caffe_fpn_4x4_1x_coco/epoch_12.pth'

# build the model from a config file and a checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:2')

# test a single image
img = 'demo.jpg'
result = inference_detector(model, img)
#print(result)
# show the results
show_result_pyplot(model, img, result)

5.3 或者使用DetVisGUI可视化

我使用的是DetVisGUI_test.py,直接显示模型预测结果,需要做的修改如下所示,然后运行下面命令即可。(注:需要将DetVisGUI_test.py、epoch_9.pth和data中的test_images文件夹拷贝到mmdetection下)

python DetVisGUI_test.py configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py epoch_9.pth data/test_images
###添加类
class KITTI_dataset:

    def __init__(self, cfg, args):
        self.dataset = 'KITTI'
        self.img_root = args.img_root
        self.config_file = args.config
        self.checkpoint_file = args.ckpt
        self.mask = False
        self.device = args.device

        # according json to get category, image list, and annotations.
        self.img_list = self.get_img_list()

        # coco categories
        self.aug_category = aug_category(['Car', 'Pedestrian', 'Cyclist'])


    def get_img_list(self):
        img_list = list()
        for image in sorted(os.listdir(self.img_root)):
            img_list.append(image)

        return img_list

    def get_img_by_name(self, name):
        img = Image.open(os.path.join(self.img_root, name)).convert('RGB')
        return img

    def get_img_by_index(self, idx):
        img = Image.open(os.path.join(self.img_root,
                                      self.img_list[idx])).convert('RGB')
        return img

###vis_tool的init中COCO改为新加的KITTI
#self.data_info = COCO_dataset(cfg, self.args)
self.data_info = KITTI_dataset(cfg, self.args)

六、参考

https://blog.csdn.net/gaoyi135/article/details/90613895

https://blog.csdn.net/xiangxianghehe/article/details/89812058#commentsedit

https://blog.csdn.net/jesse_mx/article/details/65634482

https://blog.csdn.net/yapifeitu/article/details/105884203?utm_medium=distribute.pc_relevant.none-task-blog-searchFromBaidu-5.control&depth_1-utm_source=distribute.pc_relevant.none-task-blog-searchFromBaidu-5.control

七、附录

7.1 kitti标签类别合并

kitti包括九个类别,分别是'Car', 'Van', 'Truck','Pedestrian', 'Person_sitting', 'Cyclist','Tram',  'Misc' or  'DontCare',这里将 ‘Van’, ‘Truck’, ‘Tram’ 合并到 ‘Car’ 类别中去,将 ‘Person_sitting’ 合并到 ‘Pedestrian’ 类别中去,‘Misc’ 和 ‘Dontcare’ 这两类直接忽略,最终只保留三个类别,分别是'Car', 'Pedestrian'和'Cyclist'。

这里使用的是这位博主博客中的modify_annotations_txt.py工具,源码如下,注意!!!该工具直接在标签txt文件上修改,使用之前做好备份!!!。运行工具之前需要将kitti中的label_2文件夹拷贝到data/VOCdevit/VOC2007当中并重命名为Labels,modify_annotations_txt.py工具同样置于此文件夹下,然后执行如下命令:python modify_annotations_txt.py,就可以将类别合并,该博主博客中有运行前后对比,这里不再附加。

# modify_annotations_txt.py
import glob
import string

txt_list = glob.glob('./Labels/*.txt') # 存储Labels文件夹所有txt文件路径
def show_category(txt_list):
    category_list= []
    for item in txt_list:
        try:
            with open(item) as tdf:
                for each_line in tdf:
                    labeldata = each_line.strip().split(' ') # 去掉前后多余的字符并把其分开
                    category_list.append(labeldata[0]) # 只要第一个字段,即类别
        except IOError as ioerr:
            print('File error:'+str(ioerr))
    print(set(category_list)) # 输出集合

def merge(line):
    each_line=''
    for i in range(len(line)):
        if i!= (len(line)-1):
            each_line=each_line+line[i]+' '
        else:
            each_line=each_line+line[i] # 最后一条字段后面不加空格
    each_line=each_line+'\n'
    return (each_line)

print('before modify categories are:\n')
show_category(txt_list)

for item in txt_list:
    new_txt=[]
    try:
        with open(item, 'r') as r_tdf:
            for each_line in r_tdf:
                labeldata = each_line.strip().split(' ')
                if labeldata[0] in ['Truck','Van','Tram']: # 合并汽车类
                    labeldata[0] = labeldata[0].replace(labeldata[0],'Car')
                if labeldata[0] == 'Person_sitting': # 合并行人类
                    labeldata[0] = labeldata[0].replace(labeldata[0],'Pedestrian')
                if labeldata[0] == 'DontCare': # 忽略Dontcare类
                    continue
                if labeldata[0] == 'Misc': # 忽略Misc类
                    continue
                new_txt.append(merge(labeldata)) # 重新写入新的txt文件
        with open(item,'w+') as w_tdf: # w+是打开原文件将内容删除,另写新内容进去
            for temp in new_txt:
                w_tdf.write(temp)
    except IOError as ioerr:
        print('File error:'+str(ioerr))

print('\nafter modify categories are:\n')
show_category(txt_list)

7.2 kitti转voc(可能有问题,不建议使用)

格式转换参考另一位博主的博客,我只是将JPEGImages中的图片格式由jpg改为了png,另外我的存储标记信息的文件夹是Labels,也在下面代码中有体现。我将其命名为kitti2voc.py,与Labels位于同一目录下,执行以下命令即可完成转换:python kitti2voc.py

'''author:nike hu'''
# -*- coding: utf-8 -*-

import shutil
import os
import cv2

headstr = """\
<annotation>
    <folder>VOC2007</folder>
    <filename>%06d.png</filename>
    <source>
        <database>My Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
        <flickrid>NULL</flickrid>
    </source>
    <owner>
        <flickrid>NULL</flickrid>
        <name>company</name>
    </owner>
    <size>
        <width>%d</width>
        <height>%d</height>
        <depth>%d</depth>
    </size>
    <segmented>0</segmented>
"""
objstr = """\
    <object>
        <name>%s</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>%d</xmin>
            <ymin>%d</ymin>
            <xmax>%d</xmax>
            <ymax>%d</ymax>
        </bndbox>
    </object>
"""

tailstr = '''\
</annotation>
'''




def writexml(idx, head, bbxes, tail):
    filename = ("Annotations/%06d.xml" % (idx))
    f = open(filename, "w")
    f.write(head)
    for bbx in bbxes:
        f.write(objstr % (bbx[-1], bbx[0], bbx[1], bbx[0] + bbx[2], bbx[1] + bbx[3]))
    f.write(tail)
    f.close()


def clear_dir():
    if shutil.os.path.exists(('Annotations')):
        shutil.rmtree(('Annotations'))
    if shutil.os.path.exists(('ImageSets')):
        shutil.rmtree(('ImageSets'))
    # if shutil.os.path.exists(('JPEGImages')): # 因为我们已经将所有图片放到这个文件夹里面了,所以不需要再创建了
    #     shutil.rmtree(('JPEGImages'))

    shutil.os.mkdir(('Annotations'))
    shutil.os.makedirs(('ImageSets/Main'))
    # shutil.os.mkdir(('JPEGImages'))


def excute_datasets():
    '''在Main文件夹下面要创建四个文件,trainval是总样本的百分之九十,train是总数据的百分之七十,val是总数据样本的百分之20,剩下的百分之10是测试样本'''
    ftrainval = open(('ImageSets/Main/' + 'trainval' + '.txt'), 'a')
    ftrain = open(('ImageSets/Main/' + 'train' + '.txt'), 'a')
    fval = open(('ImageSets/Main/' + 'val' + '.txt'), 'a')
    ftest = open(('ImageSets/Main/' + 'test' + '.txt'), 'a')
    images = './JPEGImages/' # 这是存储图片的位置
    txtfile = './Labels/' # 这个是是存储标记信息的文件夹
    txtlist = os.listdir(txtfile)
    lenfile = len(txtlist) # 这个是标记的信息的总的文件数量
    count = 1 # 统计正在处理的数量
    for txtname in txtlist:
        txt_path = os.path.join(txtfile, txtname)
        image_path = os.path.join(images, txtname.split('.')[0] + '.png')  # 这里是图片的路径
        im = cv2.imread(image_path)  # 读取图片
        if im is None:  # 如果不存在这张照片,跳过
            continue
        head = headstr % (int(txtname.split('.')[0]), im.shape[1], im.shape[0], im.shape[2]) # xml文件的头部分
        boxes = []
        with open(txt_path, 'r') as f:
            while True:
                txt_content = f.readline().split(' ')
                if txt_content[0] == '':
                    break
                label_name = txt_content[0]
                if label_name == 'Misc' or label_name == 'DontCare': # 如果是这两类就去掉
                    continue
                box = [float(x) for x in txt_content[4:8]] # 这个是坐标
                box.append(label_name) # 把每个坐标对应的标签加入
                boxes.append(box)
        writexml(int(txtname.split('.')[0]), head, boxes, tailstr)
        if count < 0.9 * lenfile: # 总样本的百分之90部分存到trainval
            ftrainval.write('%06d\n' % (int(txtname.split('.')[0])))
            if count < 0.7 * lenfile: # 总样本的百分之70存入train
                ftrain.write('%06d\n' % (int(txtname.split('.')[0])))
            else: # 在0.7到0.9之间的数据存入val文件
                fval.write('%06d\n' % (int(txtname.split('.')[0])))
        else:
            ftest.write('%06d\n' % (int(txtname.split('.')[0])))
        count += 1
    ftrain.close() # 运行的时候出现过没有存进去的情况,原因是数据在内存中,还没有存在磁盘中,一般程序运行完会将数据放到磁盘中,或者用close语句
    ftest.close()
    ftrainval.close()
    fval.close()



if __name__ == '__main__':
    clear_dir()
    idx = excute_datasets()
    print('Complete...')


  • 3
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
MMDetection是一个用于目标检测任务的开源框架,它可以用于各种不同的数据集评估,包括KITTIKITTI是一个常用的自动驾驶领域的数据集,包含大量的图像和对应的目标标注信息,适用于目标检测算法的评估。 在使用MMDetection进行KITTI评价时,首先需要准备KITTI数据集,并根据其提供的标注信息进行数据预处理,将其转换为MMDetection所需的格式。然后,可以使用MMDetection提供的训练和测试接口进行模型的训练和测试。 在训练过程中,可以选择使用MMDetection提供的不同网络结构和优化算法进行训练。通过迭代训练,模型可以学习到KITTI数据集中目标的特征和位置信息。 在测试过程中,可以使用训练好的模型对KITTI数据集中的图像进行目标检测。MMDetection会将检测到的目标与标注信息进行比较,计算出一系列评价指标,如精确率、召回率、平均精确率等。 根据MMDetectionKITTI的评价结果,可以评估出模型在KITTI数据集上的性能表现。通过比较不同模型的评价结果,可以选择最适合KITTI数据集的目标检测模型。同时,也可以通过观察评价结果来了解模型在不同类别目标上的性能差异,指导进一步的模型改进和优化。 综上所述,MMDetectionKITTI评价提供了方便且有效的工具,可以帮助研究者和工程师评估目标检测模型在KITTI数据集上的性能,为自动驾驶等相关应用提供支持。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值