DEFRCN训练自己的数据集

DEFRCN训练自己的数据集

问题描述

博主最近有一个数据集,想用来做个小样本检测写篇小论文,决定用DEFRCN做baseline,DEFRCN复现的博客见我的另一篇博客链接: DEFRCN代码复现记录。我想达到的效果就是使用coco数据集作为base类,而我自定义的数据集作为novel类(我的数据集7类)。记录一下实现的过程。

数据集准备

coco数据集制作

想要实现这个目标coco数据集比voc数据集要简单一点(我是这么觉得),所以要先把数据集转化为coco数据集,这个过程教程挺多,就不多赘述了。

合并数据集标注文件

要实现使用coco数据集作为base类,并在自定义数据集做微调,经过多次尝试,发现最方便的方法还是将自己的数据集与coco2014数据集进行合并。下面来讲讲合并的过程。
下面这是coco数据集合并json文件的脚本 combine.py

import json

json1_file = 'coco数据集原来的标注文件'
json2_file = '自己的coco数据集标注文件'
# 读取原始的COCO标注文件
with open(json1_file, 'r') as f:
    coco_data = json.load(f)

# 读取您的标注文件
with open(json2_file, 'r') as f:
    your_data = json.load(f)

# 将您的标注文件内容添加到原始的COCO标注文件中
coco_data['images'].extend(your_data['images'])
coco_data['annotations'].extend(your_data['annotations'])
coco_data['categories'].extend(your_data['categories'])

# 保存更新后的COCO标注文件
with open(json1_file, 'w') as f:
    json.dump(coco_data, f)

使用上面的脚本分别合并自己的数据集以及coco数据集的val.json和val.json文件,DEFRCN源码中的还需要用到cocosplit,其中有两个json文件,分别是trainvalno5k.json和5k.json。cocosplit的说明看这链接: cocosplit。根据我的理解,trainvalno5k.json类似于训练集的分割,5k.json是验证集的分割。我们要做的操作就是将自己数据集的训练集的标注文件与trainvalno5k.json合并,并用自己的验证集替代5k.json。合并后需要使用脚本修改一下类别id,图片id以及标注框id等数据避免与coco数据集原有的数据集重复冲突。

生成自己的seeds

为了进行不同shots的小样本训练,需要使用脚本分割训练集。脚本源码在这链接: prepare_coco_few_shot.py。懒得看源码可以直接复制以下代码 prepare_coco_few_shot.py

import argparse
import json
import os
import random


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--seeds", type=int, nargs="+", default=[1, 10], help="Range of seeds"
    )
    args = parser.parse_args()
    return args


def generate_seeds(args):
    data_path = "datasets/cocosplit/datasplit/trainvalno5k.json"
    data = json.load(open(data_path))

    new_all_cats = []
    for cat in data["categories"]:
        new_all_cats.append(cat)

    id2img = {}
    for i in data["images"]:
        id2img[i["id"]] = i

    anno = {i: [] for i in ID2CLASS.keys()}
    for a in data["annotations"]:
        if a["iscrowd"] == 1:
            continue
        anno[a["category_id"]].append(a)

    for i in range(args.seeds[0], args.seeds[1]):
        random.seed(i)
        for c in ID2CLASS.keys():
            img_ids = {}
            for a in anno[c]:
                if a["image_id"] in img_ids:
                    img_ids[a["image_id"]].append(a)
                else:
                    img_ids[a["image_id"]] = [a]

            sample_shots = []
            sample_imgs = []
            for shots in [1, 2, 3, 5, 10, 30]:
                while True:
                    imgs = random.sample(list(img_ids.keys()), shots)
                    for img in imgs:
                        skip = False
                        for s in sample_shots:
                            if img == s["image_id"]:
                                skip = True
                                break
                        if skip:
                            continue
                        if len(img_ids[img]) + len(sample_shots) > shots:
                            continue
                        sample_shots.extend(img_ids[img])
                        sample_imgs.append(id2img[img])
                        if len(sample_shots) == shots:
                            break
                    if len(sample_shots) == shots:
                        break
                new_data = {
                    "info": data["info"],
                    "licenses": data["licenses"],
                    "images": sample_imgs,
                    "annotations": sample_shots,
                }
                save_path = get_save_path_seeds(
                    data_path, ID2CLASS[c], shots, i
                )
                new_data["categories"] = new_all_cats
                with open(save_path, "w") as f:
                    json.dump(new_data, f)


def get_save_path_seeds(path, cls, shots, seed):
    prefix = "full_box_{}shot_{}_trainval".format(shots, cls)
    save_dir = os.path.join("datasets", "cocosplit", "seed" + str(seed))
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, prefix + ".json")
    return save_path


if __name__ == "__main__":
    ID2CLASS = {
        1: "person",
        2: "bicycle",
        3: "car",
        4: "motorcycle",
        5: "airplane",
        6: "bus",
        7: "train",
        8: "truck",
        9: "boat",
        10: "traffic light",
        11: "fire hydrant",
        13: "stop sign",
        14: "parking meter",
        15: "bench",
        16: "bird",
        17: "cat",
        18: "dog",
        19: "horse",
        20: "sheep",
        21: "cow",
        22: "elephant",
        23: "bear",
        24: "zebra",
        25: "giraffe",
        27: "backpack",
        28: "umbrella",
        31: "handbag",
        32: "tie",
        33: "suitcase",
        34: "frisbee",
        35: "skis",
        36: "snowboard",
        37: "sports ball",
        38: "kite",
        39: "baseball bat",
        40: "baseball glove",
        41: "skateboard",
        42: "surfboard",
        43: "tennis racket",
        44: "bottle",
        46: "wine glass",
        47: "cup",
        48: "fork",
        49: "knife",
        50: "spoon",
        51: "bowl",
        52: "banana",
        53: "apple",
        54: "sandwich",
        55: "orange",
        56: "broccoli",
        57: "carrot",
        58: "hot dog",
        59: "pizza",
        60: "donut",
        61: "cake",
        62: "chair",
        63: "couch",
        64: "potted plant",
        65: "bed",
        67: "dining table",
        70: "toilet",
        72: "tv",
        73: "laptop",
        74: "mouse",
        75: "remote",
        76: "keyboard",
        77: "cell phone",
        78: "microwave",
        79: "oven",
        80: "toaster",
        81: "sink",
        82: "refrigerator",
        84: "book",
        85: "clock",
        86: "vase",
        87: "scissors",
        88: "teddy bear",
        89: "hair drier",
        90: "toothbrush",
    }
    CLASS2ID = {v: k for k, v in ID2CLASS.items()}

    args = parse_args()
    generate_seeds(args)

我们要做的修改就是将ID2CLASS修改为自己的类别,这个id要和标注的类别id需要对应。使用这个脚本对自己数据集的训练集进行分割得到seed1-9的文件夹,将这个文件夹替换cocosplit中的对应内容。

代码修改

建议把修改整理得到的coco2014的数据集一样的名字并且按照源码的位置放好(记得把自己训练集和验证集的图片加到coco数据集的对对应文件夹中哈),这样会方便很多。同时把val2014的图片移动到train2014中,并将train2014文件夹修改为trainval2014(源码是这样的)。

defrcn/data/builtin_meta.py修改内容

novel类的修改

在这里插入图片描述
在这里插入图片描述

configs文件修改

对于base.yaml修改类别60为80(因为把原来的novel添加到了base类)
在这里插入图片描述
在few-shot的yaml文件中将类别数修改为自己的类别数
在这里插入图片描述
然后就可以运行了

结语

复现过程中可能还有些问题,有问题可以留言,看到会及时回复。

评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值