网上看了很多博主写的代码,写的很好,就是看起来乱乱的,所以自己看了下官网的代码,下面这个方法比较简洁一点
一、这是需要添加的代码
from detectron2.data.datasets import register_coco_instances
register_coco_instances("my_dataset_train_name", {}, "path/to/json_annotation_train.json", "path/to/image/dir")
register_coco_instances("my_dataset_val_name", {}, "path/to/json_annotation_val.json", "path/to/image/dir")
二、如何添加
1.准备自己的数据集
要确保自己的数据集的格式是coco的,不然这个方法会报错,此方法只适用于coco格式的个人数据集。
2.找到 tools/train_net.py 文件
(不同的代码可能名称不太一样,确保该文件中定义了class Trainer(DefaultTrainer),setup()和start_train())
3.在上述的文件起始部分添加代码(自己修改相应的路径)
from detectron2.data.datasets import register_coco_instances
register_coco_instances("my_dataset_train_name", {}, "path/to/json_annotation_train.json", "path/to/image/dir")
register_coco_instances("my_dataset_val_name", {}, "path/to/json_annotation_val.json", "path/to/image/dir")
4.找到 BaseRetina.yaml ,修改 DATASETS 的声明,改成自己数据集的名称
DATASETS:
TRAIN: ("my_dataset_train_name",)
TEST: ("my_dataset_val_name",)
完成修改!
注意:自己数据集如果与coco格式不一致,会有报错