DETR训练自己的数据集,yolo数据集格式转为coco数据集格式

一、数据集准备

1.1 DETR数据格式

|--- dataset
	|--- train
		|--- 1.jpg
		|--- 2.jpg
	|--- val
		|--- 1.jpg
		|--- 2.jpg
	|--- annotations
		|--- instances_train.json
		|--- instances_val.json

其中 instances_train.json 和 instances_val.json 中记录了图片标注信息,例如:

{
	"images": [
	{"file_name":"/home/shares/detr/datasets/val_xml/208.xml", 
	"height": 720, "width": 1280, "id": 208}, 
	{"file_name":"/home/shares/detr/datasets/val_xml/468.xml", 
	"height": 720, "width": 1280, "id": 468},
	...
	]
	 "type": "instances",
	 "annotations": [
	 {"area": 14151, "iscrowd": 0, "image_id": 208, "bbox": [360, 521, 159, 89], "category_id": 3, "id": 1, "ignore": 0, "segmentation": []},
	 {"area": 21890, "iscrowd": 0, "image_id": 468, "bbox": [209, 382, 110, 199], "category_id": 2, "id": 2, "ignore": 0, "segmentation": []},
	...
	]
	"categories": [
	{"supercategory": "none", "id": 1, "name": "cigaretteface"}, 
	{"supercategory": "none", "id": 2, "name": "smokeface"}, 
	{"supercategory": "none", "id": 3, "name": "normalface"}, 
	{"supercategory": "none", "id": 4, "name": "callface"}
	]
}

1.2 yolo数据集格式转为coco数据集格式

若已有coco数据集格式则跳过此步骤

1.2.1 yolo数据集格式如下:

|--- /home/shares/datasets/my_voc_dataset
	|--- Annotations
		|--- 1.xml
		|--- 2.xml
	|--- ImageSets
		|--- Main
			|--- train.txt
			|--- val.txt
	|--- JPEGImages
		|--- 1.jpg
		|--- 2.jpg

1.2.2 使用下述代码可将此种结构目录转换为coco格式:

import os
import json
import glob
import shutil
import xml.etree.ElementTree as ET

# 定义标签从1开始编号并给定类别对应数字标签
START_BOUNDING_BOX_ID = 1
PRE_DEFINE_CATEGORIES = {"cigaretteface": 1, "smokeface": 2, "normalface": 3, "callface": 4}

def get(root, name):
    vars = root.findall(name)
    return vars

def get_and_check(root, name, length):
    vars = root.findall(name)
    if len(vars) == 0:
        raise ValueError("Can not find %s in %s." % (name, root.tag))
    if length > 0 and len(vars) != length:
        raise ValueError(
            "The size of %s is supposed to be %d, but is %d."
            % (name, length, len(vars))
        )
    if length == 1:
        vars = vars[0]
    return vars

def get_filename_as_int(filename):
    try:
        filename = filename.replace("\\", "/")
        filename = os.path.splitext(os.path.basename(filename))[0]
        return int(filename)
    except:
        raise ValueError(
            "Filename %s is supposed to be an integer." % (filename))

def get_categories(xml_files):
    """Generate category name to id mapping from a list of xml files.

    Arguments:
        xml_files {list} -- A list of xml file paths.

    Returns:
        dict -- category name to id mapping.
    """
    classes_names = []
    for xml_file in xml_files:
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall("object"):
            classes_names.append(member[0].text)
    classes_names = list(set(classes_names))
    classes_names.sort()
    return {name: i for i, name in enumerate(classes_names)}

def convert(xml_files, json_file):
    json_dict = {"images": [], "type": "instances",
                 "annotations": [], "categories": []}
    if PRE_DEFINE_CATEGORIES is not None:
        categories = PRE_DEFINE_CATEGORIES
    else:
        categories = get_categories(xml_files)
    bnd_id = START_BOUNDING_BOX_ID
    nums = len(xml_files)
    i = 1
    for xml_file in xml_files:
        print('\r converting xml to json : {}/{}'.format(i, nums), end = "")
        i += 1
        tree = ET.parse(xml_file)
        root = tree.getroot()
        path = get(root, "path")

        # The filename must be a number
        image_id = get_filename_as_int(xml_file)
        size = get_and_check(root, "size", 1)
        width = int(get_and_check(size, "width", 1).text)
        height = int(get_and_check(size, "height", 1).text)
        image = {
            "file_name": xml_file,
            "height": height,
            "width": width,
            "id": image_id,
        }
        json_dict["images"].append(image)
        # Currently we do not support segmentation.
        #  segmented = get_and_check(root, 'segmented', 1).text
        #  assert segmented == '0'
        for obj in get(root, "object"):
            category = get_and_check(obj, "name", 1).text
            if category not in categories:
                new_id = len(categories)
                categories[category] = new_id
            category_id = categories[category]
            bndbox = get_and_check(obj, "bndbox", 1)
            xmin = int(get_and_check(bndbox, "xmin", 1).text) - 1
            ymin = int(get_and_check(bndbox, "ymin", 1).text) - 1
            xmax = int(get_and_check(bndbox, "xmax", 1).text)
            ymax = int(get_and_check(bndbox, "ymax", 1).text)
            assert xmax > xmin
            assert ymax > ymin
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            ann = {
                "area": o_width * o_height,
                "iscrowd": 0,
                "image_id": image_id,
                "bbox": [xmin, ymin, o_width, o_height],
                "category_id": category_id,
                "id": bnd_id,
                "ignore": 0,
                "segmentation": [],
            }
            json_dict["annotations"].append(ann)
            bnd_id = bnd_id + 1
    print()
    for cate, cid in categories.items():
        cat = {"supercategory": "none", "id": cid, "name": cate}
        json_dict["categories"].append(cat)

    os.makedirs(os.path.dirname(json_file), exist_ok=True)
    json_fp = open(json_file, "w")
    json_str = json.dumps(json_dict)
    json_fp.write(json_str)
    json_fp.close()


if __name__ == "__main__":
    voc_path = "/home/shares/datasets/my_voc_dataset"
    
    #  保存coco格式数据集根目录
    save_coco_path = "/home/shares/detr/datasets"
    
    #  VOC只分了训练集和验证集即train.txt和val.txt
    data_type_list = ["train", "val"]
    for data_type in data_type_list:
        try:
            os.makedirs(os.path.join(save_coco_path, data_type))
            os.makedirs(os.path.join(save_coco_path, data_type+"_xml"))
            with open(os.path.join(voc_path, "ImageSets"+os.sep+"Main", data_type+".txt"), "r") as f:
                txt_ls = f.readlines()
            txt_ls = [i.strip() for i in txt_ls]
            idx = 0
            for i in os.listdir(os.path.join(voc_path, "JPEGImages")):
                print('\rcopying imgs', end = "")
                if os.path.splitext(i)[0] in txt_ls:
                    shutil.copy(os.path.join(voc_path, "JPEGImages", i),
                                os.path.join(save_coco_path, data_type, str(idx) + ".jpg"))
                    shutil.copy(os.path.join(voc_path, "Annotations", i[:-4]+".xml"), os.path.join(
                        save_coco_path, data_type+"_xml", str(idx)+".xml"))
                    idx += 1
        except:
            print("sdfsf")
        xml_path = os.path.join(save_coco_path, data_type+"_xml")
        xml_files = glob.glob(os.path.join(xml_path, "*.xml"))
        convert(xml_files, os.path.join(save_coco_path,
                "annotations", "instances_"+data_type+".json"))
        shutil.rmtree(xml_path)

执行完上述代码后,会自动创建路径
a. save_coco_path
b. save_coco_path/train
c. save_coco_path/val
且train、val下的图片数据从0开始编号

二、修改训练参数

进入detr/main.py编辑get_args_parser()

# dataset parameters
    parser.add_argument('--coco_path', type=str)
    parser.add_argument('--coco_panoptic_path', type=str)
	...
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')

coco_path:数据集路径,改为上述 save_coco_path 路径
output_dir:训练结果保存路径,eg:runs/result1
在对应参数设置里加入default关键字并赋值,修改后为:

# dataset parameters
    parser.add_argument('--coco_path', type=str, default = "/home/shares/detr/datasets")
    parser.add_argument('--coco_panoptic_path', type=str)
	...
    parser.add_argument('--output_dir', default="/home/shares/detr/runs/result1",
                        help='path where to save, empty for no saving')

三、训练

在终端输入命令

python main.py

运行后终端开始打印结果
在这里插入图片描述

四、训练结果

在runs/result1目录下会生成以下文件/文件夹
在这里插入图片描述
其中,checkpoint.pth即为训练完成权重;log.txt为训练记录内容,

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值