Creating a Training Dataset That Meets MMDetection's Requirements


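Preface

MMDetection is an open-source object detection project from OpenMMLab. This article first briefly covers installing MMDetection, then explains how to convert your own annotation files (for example xml or txt) into the label format that MMDetection expects.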


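1. Installing MMDetection

If you are installing MMDetection on a local machine, make sure Python and Anaconda are installed correctly. On a remote server, Python and Miniconda are usually set up by default.

The steps below follow the official OpenMMLab installation instructions: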

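# Create the environment (the name "openmmlab" is arbitrary; pick your own)
# Choose a cudatoolkit version that matches your GPU driver and the PyTorch build
conda create -n openmmlab python=3.8 pytorch==1.10.0 cudatoolkit=10.1 torchvision -c pytorch -y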
conda activate openmmlab
pip install openmim
mim install mmcv-full
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
pip install -r requirements/build.txt
pip install -v -e .
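
After the commands above finish, a quick check that the packages can be found by Python may save time later. The sketch below is not part of the official instructions; the helper name check_packages is hypothetical, and the import names (torch, torchvision, mmcv, mmdet) are assumed from the packages installed above.

```python
import importlib.util


def check_packages(names):
    """Return a dict mapping each top-level package name to True if it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Import names assumed from the install steps: mmcv comes from mmcv-full,
    # mmdet from the editable install of the mmdetection repository.
    for name, found in check_packages(["torch", "torchvision", "mmcv", "mmdet"]).items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```

Run it inside the activated conda environment; any package reported as MISSING points to the install step that needs to be repeated.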

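2. Creating a COCO-Format Dataset

COCO format is used because most of MMDetection's data handling is built around it, so converting to COCO makes training and testing straightforward. I assume the dataset has already been annotated and split (using the YOLOv5 layout as an example):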

NeedTrainImageFold:
	images:
		train: training images
		val: validation images
		test: test images
	labels:
		train: training labels
		val: validation labels
		test: test labels
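
Before converting, it can be worth confirming that every image in this layout has a label file with the same stem. The following is a minimal sketch under the assumption that labels are .txt files named after their images; the function name and the suffix list are illustrative choices, not part of any library.

```python
from pathlib import Path

# Common image suffixes (an assumption; adjust to your data)
IMG_SUFFIXES = {".bmp", ".jpeg", ".jpg", ".png", ".tif", ".tiff", ".webp"}


def find_unlabeled_images(root, split):
    """List images in root/images/<split> that lack root/labels/<split>/<stem>.txt."""
    image_dir = Path(root) / "images" / split
    label_dir = Path(root) / "labels" / split
    missing = []
    for image_path in sorted(image_dir.iterdir()):
        if image_path.suffix.lower() not in IMG_SUFFIXES:
            continue  # skip non-image files
        if not (label_dir / f"{image_path.stem}.txt").exists():
            missing.append(image_path.name)
    return missing
```

For example, find_unlabeled_images("NeedTrainImageFold", "train") returns the names of training images without a matching label file, and an empty list when the split is complete.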

If your label files are xml or txt, they need to be converted to json. Two open-source libraries handle this: for xml labels, use voc2coco; for txt labels, use YOLO2COCO.
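
To make the txt case concrete: a YOLO label line stores a class index followed by the box centre and size, all normalized to the image dimensions, while a COCO bbox is [x_min, y_min, width, height] in pixels. The sketch below shows only that arithmetic for a single line; the function name is hypothetical, and it does no clipping to the image bounds.

```python
def yolo_line_to_coco_bbox(line, img_width, img_height):
    """Convert one YOLO label line to (class_index, [x_min, y_min, width, height]) in pixels."""
    class_index, cx, cy, w, h = line.split()
    box_w = float(w) * img_width
    box_h = float(h) * img_height
    # YOLO stores the box centre; COCO stores the top-left corner
    x_min = float(cx) * img_width - box_w / 2
    y_min = float(cy) * img_height - box_h / 2
    return int(class_index), [x_min, y_min, box_w, box_h]


# A box centred in a 200x100 image, covering half of each dimension
print(yolo_line_to_coco_bbox("0 0.5 0.5 0.5 0.5", 200, 100))
# prints (0, [50.0, 25.0, 100.0, 50.0])
```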

Since I mostly work with txt labels, I will demonstrate how to use YOLO2COCO:

First, clone the project: git clone https://github.com/RapidAI/YOLO2COCO.git (make sure git is installed, otherwise the clone will fail)

The project offers two ways to convert to COCO: YOLOV5 -> COCO and YOLOV5 YAML -> COCO.
Here I demonstrate YOLOV5 YAML -> COCO. You first need to create a YAML file in the following format:

# This config file is the same as the one used by the yolov5 project

# dataset root dir
path: dataset/YOLOV5_yaml
train:  # train images (relative to 'path')
  - images/train
val: # val images (relative to 'path')
  - images/val

# Classes
nc: 1  # number of classes
names: ['stamp']  # class names

Then run the following in a terminal:

python yolov5_yaml_2_coco.py --yaml_path dataset/YOLOV5_yaml/sample.yaml  # replace with the path to your own YAML file
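
Before running the converter, it may help to sanity-check the config. The sketch below works on an already-loaded config dict (for example the result of yaml.safe_load on your YAML file) and reports missing splits and a mismatch between nc and names. The function check_dataset_cfg and its required_splits parameter are hypothetical helpers for illustration, not part of YOLO2COCO.

```python
def check_dataset_cfg(cfg, required_splits=("train", "val")):
    """Return a list of problems found in an already-loaded dataset config dict."""
    problems = []
    for split in required_splits:
        if not cfg.get(split):
            problems.append(f"missing split: {split}")
    names = cfg.get("names")
    if names is not None and len(names) != cfg.get("nc"):
        problems.append(f"nc={cfg.get('nc')} but {len(names)} names given")
    return problems


# The same values as the sample YAML above, written as a Python dict
sample_cfg = {
    "path": "dataset/YOLOV5_yaml",
    "train": ["images/train"],
    "val": ["images/val"],
    "nc": 1,
    "names": ["stamp"],
}
print(check_dataset_cfg(sample_cfg))
# prints []
print(check_dataset_cfg(sample_cfg, required_splits=("train", "val", "test")))
# prints ['missing split: test']
```

Passing required_splits=("train", "val", "test") is useful when you also expect a test split to be converted.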

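Important

If your dataset is split like mine above and includes a test set, running the command above will not create a test2017 folder, and the annotations folder will not contain a json file for the test labels.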

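I made a small change to the author's yolov5_yaml_2_coco.py. Copy all of the code below and use it to replace the contents of yolov5_yaml_2_coco.py. Note that the modified script also reads a test entry from the YAML file, so add one next to train and val (for example images/test); without it the script raises a ValueError.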

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import argparse
import glob
import json
import os
import shutil
import time
from pathlib import Path

import cv2
import yaml
from tqdm import tqdm


def read_txt(txt_path):
    with open(str(txt_path), 'r', encoding='utf-8') as f:
        data = list(map(lambda x: x.rstrip('\n'), f))
    return data


def mkdir(dir_path):
    Path(dir_path).mkdir(parents=True, exist_ok=True)


def verify_exists(file_path):
    file_path = Path(file_path).resolve()
    if not file_path.exists():
        raise FileNotFoundError(f'{file_path} does not exist!')


class YOLOV5CFG2COCO(object):
    def __init__(self, yaml_path):
        verify_exists(yaml_path)
        with open(yaml_path, 'r', encoding="UTF-8") as f:
            self.data_cfg = yaml.safe_load(f)

        self.root_dir = Path(yaml_path).parent.parent
        self.root_data_dir = Path(self.data_cfg.get('path'))

        self.train_path = self._get_data_dir('train')
        self.val_path = self._get_data_dir('val')
        self.test_path = self._get_data_dir('test')

        nc = self.data_cfg['nc']

        if 'names' in self.data_cfg:
            self.names = self.data_cfg.get('names')
        else:
            # assign class names if missing
            self.names = [f'class{i}' for i in range(self.data_cfg['nc'])]

        assert len(self.names) == nc, \
            f'{len(self.names)} names found for nc={nc} dataset in {yaml_path}'

        # Build the COCO-format directory layout
        self.dst = self.root_dir / f"{Path(self.root_data_dir).stem}_COCO_format"
        self.coco_train = "train2017"
        self.coco_val = "val2017"
        self.coco_test = "test2017"
        self.coco_annotation = "annotations"
        self.coco_train_json = self.dst / self.coco_annotation / \
            f'instances_{self.coco_train}.json'
        self.coco_val_json = self.dst / self.coco_annotation / \
            f'instances_{self.coco_val}.json'
        self.coco_test_json = self.dst / self.coco_annotation / \
            f'instances_{self.coco_test}.json'

        mkdir(self.dst)
        mkdir(self.dst / self.coco_train)
        mkdir(self.dst / self.coco_val)
        mkdir(self.dst / self.coco_test)
        mkdir(self.dst / self.coco_annotation)

        # Build the JSON content structure
        self.type = 'instances'
        self.categories = []
        self._get_category()
        self.annotation_id = 1

        cur_year = time.strftime('%Y', time.localtime(time.time()))
        self.info = {
            'year': int(cur_year),
            'version': '1.0',
            'description': 'For object detection',
            'date_created': cur_year,
        }

        self.licenses = [{
            'id': 1,
            'name': 'Apache License v2.0',
            'url': 'https://github.com/RapidAI/YOLO2COCO/LICENSE',
        }]

    def _get_data_dir(self, mode):
        data_dir = self.data_cfg.get(mode)
        if data_dir:
            if isinstance(data_dir, str):
                full_path = [str(self.root_data_dir / data_dir)]
            elif isinstance(data_dir, list):
                full_path = [str(self.root_data_dir / one_dir)
                             for one_dir in data_dir]
            else:
                raise TypeError(f'{data_dir} is not str or list.')
        else:
            raise ValueError(f'{mode} dir is not in the yaml.')
        return full_path

    def _get_category(self):
        for i, category in enumerate(self.names, start=1):
            self.categories.append({
                'supercategory': category,
                'id': i,
                'name': category,
            })

    def generate(self):
        self.train_files = self.get_files(self.train_path)
        self.valid_files = self.get_files(self.val_path)
        self.test_files = self.get_files(self.test_path)

        train_dest_dir = Path(self.dst) / self.coco_train
        self.gen_dataset(self.train_files, train_dest_dir,
                         self.coco_train_json, mode='train')

        val_dest_dir = Path(self.dst) / self.coco_val
        self.gen_dataset(self.valid_files, val_dest_dir,
                         self.coco_val_json, mode='val')
        
        test_dest_dir = Path(self.dst) / self.coco_test
        self.gen_dataset(self.test_files, test_dest_dir,
                         self.coco_test_json, mode='test')

        print(f"The output directory is: {str(self.dst)}")

    def get_files(self, path):
        # include image suffixes
        IMG_FORMATS = ['bmp', 'dng', 'jpeg', 'jpg',
                       'mpo', 'png', 'tif', 'tiff', 'webp']
        f = []
        for p in path:
            p = Path(p)  # os-agnostic
            if p.is_dir():  # dir
                f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                # f = list(p.rglob('*.*'))  # pathlib
            elif p.is_file():  # file
                with open(p) as t:
                    t = t.read().strip().splitlines()
                    parent = str(p.parent) + os.sep
                    # local to global path
                    f += [x.replace('./', parent)
                          if x.startswith('./') else x for x in t]
                    # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
            else:
                raise Exception(f'{p} does not exist')

        im_files = sorted(x.replace('/', os.sep)
                          for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
        return im_files

    def gen_dataset(self, img_paths, target_img_path, target_json, mode):
        """
        https://cocodataset.org/#format-data

        """
        images = []
        annotations = []
        sa, sb = os.sep + 'images' + os.sep, os.sep + \
            'labels' + os.sep  # /images/, /labels/ substrings

        for img_id, img_path in enumerate(tqdm(img_paths, desc=mode), 1):
            label_path = sb.join(img_path.rsplit(
                sa, 1)).rsplit('.', 1)[0] + '.txt'

            img_path = Path(img_path)

            verify_exists(img_path)

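            imgsrc = cv2.imread(str(img_path))
            if imgsrc is None:
                # cv2.imread returns None when the file cannot be decoded
                raise ValueError(f'Failed to read image: {img_path}')
            height, width = imgsrc.shape[:2]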

            dest_file_name = f'{img_id:012d}.jpg'
            save_img_path = target_img_path / dest_file_name

            if img_path.suffix.lower() == ".jpg":
                shutil.copyfile(img_path, save_img_path)
            else:
                cv2.imwrite(str(save_img_path), imgsrc)

            images.append({
                'date_captured': '2021',
                'file_name': dest_file_name,
                'id': img_id,
                'height': height,
                'width': width,
            })

            if Path(label_path).exists():
                new_anno = self.read_annotation(label_path, img_id,
                                                height, width)
                if len(new_anno) > 0:
                    annotations.extend(new_anno)
                else:
                    # print(f'{label_path} is empty')
                    raise ValueError(f'{label_path} is empty')
            else:
                raise FileNotFoundError(f'{label_path} does not exist')

        json_data = {
            'info': self.info,
            'images': images,
            'licenses': self.licenses,
            'type': self.type,
            'annotations': annotations,
            'categories': self.categories,
        }
        with open(target_json, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, ensure_ascii=False)

    def read_annotation(self, txt_file, img_id, height, width):
        annotation = []
        all_info = read_txt(txt_file)
        for label_info in all_info:
            # Iterate over the annotated objects in one image;
            # split() tolerates repeated or trailing whitespace
            label_info = label_info.split()
            if len(label_info) < 5:
                continue

            category_id, vertex_info = label_info[0], label_info[1:]
            segmentation, bbox, area = self._get_annotation(vertex_info,
                                                            height, width)
            annotation.append({
                'segmentation': segmentation,
                'area': area,
                'iscrowd': 0,
                'image_id': img_id,
                'bbox': bbox,
                'category_id': int(category_id)+1,
                'id': self.annotation_id,
            })
            self.annotation_id += 1
        return annotation

    @staticmethod
    def _get_annotation(vertex_info, height, width):
        cx, cy, w, h = [float(i) for i in vertex_info]

        cx = cx * width
        cy = cy * height
        box_w = w * width
        box_h = h * height

        # left top
        x0 = max(cx - box_w / 2, 0)
        y0 = max(cy - box_h / 2, 0)

        # right bottom
        x1 = min(x0 + box_w, width)
        y1 = min(y0 + box_h, height)

        segmentation = [[x0, y0, x1, y0, x1, y1, x0, y1]]
        bbox = [x0, y0, box_w, box_h]
        area = box_w * box_h
        return segmentation, bbox, area


if __name__ == "__main__":
    parser = argparse.ArgumentParser('Datasets converter from YOLOV5 to COCO')
    parser.add_argument('--yaml_path', type=str,
                        default='dataset/YOLOV5_yaml/sample.yaml',
                        help='Dataset cfg file')
    args = parser.parse_args()

    converter = YOLOV5CFG2COCO(args.yaml_path)
    converter.generate()

Then run the following in a terminal:

python yolov5_yaml_2_coco.py --yaml_path dataset/YOLOV5_yaml/sample.yaml

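The test2017 folder and the json file with the test set annotations are now generated as well.

That's it: the COCO-format dataset has been created, and the process is fairly simple.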
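
To confirm the result, the generated annotation files can be loaded and checked for basic consistency. This is a minimal sketch using only the standard library; summarize_coco is a hypothetical helper, and it relies only on the top-level keys the script writes (images, annotations, categories).

```python
import json
from pathlib import Path


def summarize_coco(json_path):
    """Load a COCO annotation file and return counts plus annotations with dangling ids."""
    data = json.loads(Path(json_path).read_text(encoding="utf-8"))
    image_ids = {img["id"] for img in data["images"]}
    category_ids = {cat["id"] for cat in data["categories"]}
    # Annotations that point at an image or category not present in the file
    orphan_annotations = [
        ann["id"] for ann in data["annotations"]
        if ann["image_id"] not in image_ids or ann["category_id"] not in category_ids
    ]
    return {
        "images": len(data["images"]),
        "annotations": len(data["annotations"]),
        "categories": len(data["categories"]),
        "orphan_annotations": orphan_annotations,
    }
```

With the sample paths used in this article, the script's naming scheme should place the test annotations at dataset/YOLOV5_yaml_COCO_format/annotations/instances_test2017.json. Calling summarize_coco on that file should report as many images as the test split contains and an empty orphan_annotations list.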

Summary

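  • Install MMDetection
  • Prepare a fully annotated dataset
  • Check your label file format and use the matching open-source library
  • If your dataset layout matches mine, see the Important note and replace the author's script with the modified version
The next article will cover how to train a model with the dataset prepared in this section.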