参考:https://mmdetection.readthedocs.io/zh-cn/v2.21.0/tutorials/customize_dataset.html
一般来说,新增了分类不生效可以检查一下几点:
样本方面:
- 数据集中是否有分类且分类id是否有标注,对于 COCO 数据集检查一下标注文件中的:categories 和 annotations.category_id 是否有对应的数值。用 grep -rn ““category_id”: 14” /path/to/train.json | more 脚本确认一下;
- 配置文件中检查 resume 取值和 [test_dataloader|val_dataloader|train_dataloader] 的 type, metainfo 等属性;
检查方法:
- 在训练有一个 epoch 之后就可以那到权重文件检查分类:torch.load(‘epoch_39.pth’)[‘meta’][‘dataset_meta’];
- 在 配置文件的增加 test_evaluator.classwise=True 配置,训练一个 epoch 之后的验证便会输出所有分类指标,也可以用来判断分类是否正确加入;
本项目自定义数据集的方式是:
- 在 mmdet.datasets. 包下新增了自定义数据集文件,继承 CocoDataset 并重写了 METAINFO 成员,成员中定义了 13 个检测分类;
- 在 mmdet.datasets.init.py 文件中导入自定义数据集类;
- 将配置文件 [test_dataloader|val_dataloader|train_dataloader].dataset.type 参数指定为自定义数据集类;
一共是 13 个类别。代码如下:
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.registry import DATASETS
import warnings
from .coco import CocoDataset
from .api_wrappers import COCO
@DATASETS.register_module()
class VOCSDTDKSHDataset(CocoDataset):
"""
*** 数据集, VOC 标注格式。
@auther chenghj11
"""
METAINFO = {
'classes': ('DaoXianYiWu'...), # 省略其他 12 个检测类别
'palette': [(106, 0, 228)...]
}
问题现象
新增了一个检测类别后,总的类别数量因为 14 类。重新训练模型后用 torch.load(‘/path/to/new/model’)[‘meta’][‘meta_class’] 还是只能获取到 13 个检测类别。
python -c "import torch; print(torch.load('epoch_39.pth')['meta']['dataset_meta'])"
解决方法
通过查看 mmengine 源码 mmengine/dataset/base_dataset.py:BaseDataset 得知数据集定义中有一个 metainfo 的参数。所以只需要在配置文件中数据集相关的配置上加上 metainfo 参数即可。
步骤如下:
- 在配置文件,增加 metainfo = {‘classes’: (‘DaoXianYiWu’, …), ‘palette’: ((106, 0, 228), …)}, num_classes = len(metainfo[‘classes’]) 两个变量定义;
- 在 model.roi_head.bbox_head 中增加 num_classes=num_classes, 定义;
- 在 test_dataloader.dataset 中增加 metainfo=metainfo, 定义;
- 在 val_dataloader.dataset 中增加 metainfo=metainfo, 定义;
- 在 train_dataloader.dataset 中增加 metainfo=metainfo, 定义;
- 在配置文件中修改 resume 配置: resume=False;
调试确认在 mmdet.datasets.coco.CocoDataset.load_data_list 方法中能得到正确的分类数量:
def load_data_list(self) -> List[dict]:
"""Load annotations from an annotation file named as ``self.ann_file``
Returns:
List[dict]: A list of annotation.
""" # noqa: E501
with get_local_path(
self.ann_file, backend_args=self.backend_args) as local_path:
self.coco = self.COCOAPI(local_path)
# The order of returned `cat_ids` will not
# change with the order of the `classes`
self.cat_ids = self.coco.get_cat_ids(
cat_names=self.metainfo['classes'])
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)