mmdetection 自定义数据集新增了检测类别不生效的问题

参考:https://mmdetection.readthedocs.io/zh-cn/v2.21.0/tutorials/customize_dataset.html
一般来说,新增了分类不生效可以检查一下几点:
样本方面:

  1. 数据集中是否有分类且分类id是否有标注,对于 COCO 数据集检查一下标注文件中的:categories 和 annotations.category_id 是否有对应的数值。用 grep -rn ““category_id”: 14” /path/to/train.json | more 脚本确认一下;
  2. 配置文件中检查 resume 取值和 [test_dataloader|val_dataloader|train_dataloader] 的 type, metainfo 等属性;

检查方法:

  1. 在训练有一个 epoch 之后就可以那到权重文件检查分类:torch.load(‘epoch_39.pth’)[‘meta’][‘dataset_meta’];
  2. 在 配置文件的增加 test_evaluator.classwise=True 配置,训练一个 epoch 之后的验证便会输出所有分类指标,也可以用来判断分类是否正确加入;

本项目自定义数据集的方式是:

  1. 在 mmdet.datasets. 包下新增了自定义数据集文件,继承 CocoDataset 并重写了 METAINFO 成员,成员中定义了 13 个检测分类;
  2. 在 mmdet.datasets.init.py 文件中导入自定义数据集类;
  3. 将配置文件 [test_dataloader|val_dataloader|train_dataloader].dataset.type 参数指定为自定义数据集类;
    一共是 13 个类别。代码如下:
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.registry import DATASETS
import warnings
from .coco import CocoDataset
from .api_wrappers import COCO


@DATASETS.register_module()
class VOCSDTDKSHDataset(CocoDataset):
    """
    *** 数据集, VOC 标注格式。
    @auther chenghj11
    """
    METAINFO = {
        'classes': ('DaoXianYiWu'...), # 省略其他 12 个检测类别
        'palette': [(106, 0, 228)...]
    }

问题现象

新增了一个检测类别后,总的类别数量因为 14 类。重新训练模型后用 torch.load(‘/path/to/new/model’)[‘meta’][‘meta_class’] 还是只能获取到 13 个检测类别。

python -c "import torch; print(torch.load('epoch_39.pth')['meta']['dataset_meta'])"

解决方法

通过查看 mmengine 源码 mmengine/dataset/base_dataset.py:BaseDataset 得知数据集定义中有一个 metainfo 的参数。所以只需要在配置文件中数据集相关的配置上加上 metainfo 参数即可。
metainfo 参数
步骤如下:

  1. 在配置文件,增加 metainfo = {‘classes’: (‘DaoXianYiWu’, …), ‘palette’: ((106, 0, 228), …)}, num_classes = len(metainfo[‘classes’]) 两个变量定义;
  2. 在 model.roi_head.bbox_head 中增加 num_classes=num_classes, 定义;
  3. 在 test_dataloader.dataset 中增加 metainfo=metainfo, 定义;
  4. 在 val_dataloader.dataset 中增加 metainfo=metainfo, 定义;
  5. 在 train_dataloader.dataset 中增加 metainfo=metainfo, 定义;
  6. 在配置文件中修改 resume 配置: resume=False

调试确认在 mmdet.datasets.coco.CocoDataset.load_data_list 方法中能得到正确的分类数量:

    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named as ``self.ann_file``

        Returns:
            List[dict]: A list of annotation.
        """  # noqa: E501
        with get_local_path(
                self.ann_file, backend_args=self.backend_args) as local_path:
            self.coco = self.COCOAPI(local_path)
        # The order of returned `cat_ids` will not
        # change with the order of the `classes`
        self.cat_ids = self.coco.get_cat_ids(
            cat_names=self.metainfo['classes'])
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)

metainfo 定义示例

  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值