mmdetection之dataset类解读


前言

 本篇是mmdetection源码解读第二篇,主要讲解mmdetection是初始化数据类的。本文以coco数据集为例,当然,源码解读不可能面面俱到,重要的是揣摩设计者的思想以及实现过程。另外,本文先暂时不予介绍dataloader构建过程。


1、总体流程

 通常我们利用pytorch读取数据集需要构建两个部分,一个是数据集初始化,主要完成数据集的存储路径;一个是实现getitem方法,变成迭代器来训练模型:
在这里插入图片描述
 这里解释下pipline。在mmdetection中,pipline实际上是一系列顺序的关于图像读取,增强,合并的函数。即实例了一个图像增强对象,之后在getitem中利用transforms对data进行增强。这里简单有个理解即可。后续我会详细介绍。

2、实现流程

2.1. coco_detection训练配置文件

 截取mmdetection中用于train的训练集的配置文件。代码:configs/_base_/datasets/coco_detection.。

dataset_type = 'CocoDataset'
data_root = '/home/wujian/WLL/mmdet-master/data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(800, 512), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]

 上述配置文件中,dataset_type表示读取coco格式的数据集。data_root是数据集存储路径。train_pipline用于图像增强函数的参数文件。

2.2. CocoDataset初始化

 mmdetection中使用build_dataset函数来完成dataset实例化。

datasets = [build_dataset(cfg.data.train)]

  这里内部build_dataset实质上内部调用了build_from_cfg函数(这一块我不介绍了,要不太冗余了,主要理解设计思想),这个函数将cfg文件用于CocoDataset类初始化,而CocoDataset类继承自CustomDataset类,我主要截取重要部分,地址:mmdet/datasets/custom.py。

@DATASETS.register_module()
class CustomDataset(Dataset):
    CLASSES = None
    def __init__(self,
                 ann_file,
                 pipeline,
                 classes=None,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True):
        self.ann_file = ann_file
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.seg_prefix = seg_prefix
        self.proposal_file = proposal_file
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = self.get_classes(classes)
        # load annotations (and proposals)
        self.data_infos = self.load_annotations(self.ann_file)

        # processing pipeline
        self.pipeline = Compose(pipeline)

 这里初始化了data_root,值得注意的是最后一行self.pipline = Compose(pipline),这就是第一部分实例化了一个图像增强的类。
 我们看下Compose类:

@PIPELINES.register_module()
class Compose(object):
    def __init__(self, transforms):
        assert isinstance(transforms, collections.abc.Sequence)
        self.transforms = []    # transforms即传入的一个长度为8,且每个元素是字典的list。[{'type':'LoadImageFromFile'}]
        for transform in transforms:
            if isinstance(transform, dict):
                transform = build_from_cfg(transform, PIPELINES)
                self.transforms.append(transform)

 参数transfoms是个长度为8的list,各个元素是字典,字典的内容就是train_pipline中内容。举个例子:
trainsfoms=[{‘type’:LoadImageFromFile},{‘type’:LoadAnnotations}。在Compose初始化中,通过遍历transforms里面的8个元素,利用build_from_cfg函数完成了各个类的实例化,之后将各个实例对象append进self.transforms列表中。至此,Compose类实际上里面存储的是顺序的图像增强实例对象。至此,CocoDataset初始化部分完成。

2.3. CocoDataset中getitem实现

 放下getitem函数,依旧在CustomDataset类内:

    def __getitem__(self, idx):
        if self.test_mode:
            return self.prepare_test_img(idx)
        while True:
            data = self.prepare_train_img(idx)
            if data is None:
                idx = self._rand_another(idx)   # 这里写的鲁棒,若idx失效,则随机读取另一张图像
                continue
            return data
            
    def prepare_train_img(self, idx):
        img_info = self.data_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

 从函数可以看出:首先借助idx读取data,然后利用prepare_train_img完成data的图像增强。之后,return data。

总结

 以上就是mmdetection中dataset类实例过程。首先初始化路径以及完成图像增强pipline的实例。然后完成getitem函数。若有问题欢迎+vx:wulele2541612007,拉你进群探讨交流。

  • 10
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
mmdetection是一个基于PyTorch的目标检测框架,其配置文件是控制模型训练、测试和推理的重要参数。下面是一个mmdetection配置文件的解读: ```python # model settings model = dict( type='RetinaNet', pretrained='torchvision://resnet50', backbone=dict( type='ResNet', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), frozen_stages=1, norm_cfg=dict(type='BN', requires_grad=True), norm_eval=True, style='pytorch'), neck=dict( type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5), bbox_head=dict( type='RetinaHead', num_classes=80, in_channels=256, stacked_convs=4, feat_channels=256, octave_base_scale=4, scales_per_octave=3, anchor_ratios=[0.5, 1.0, 2.0], anchor_strides=[8, 16, 32, 64, 128], target_means=[.0, .0, .0, .0], target_stds=[1.0, 1.0, 1.0, 1.0], loss_cls=dict( type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=0.11, loss_weight=1.0)), # training and testing settings train_cfg=dict( assigner=dict( type='MaxIoUAssigner', pos_iou_thr=0.5, neg_iou_thr=0.4, min_pos_iou=0, ignore_iof_thr=-1), smoothl1_beta=1.0, allowed_border=-1, pos_weight=-1, debug=False), test_cfg=dict( nms_pre=1000, min_bbox_size=0, score_thr=0.05, nms=dict(type='nms', iou_threshold=0.5), max_per_img=100)) # dataset settings dataset_type = 'CocoDataset' data_root = 'data/coco/' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', with_bbox=True), dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), dict(type='RandomFlip', flip_ratio=0.5), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size_divisor=32), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), ] test_pipeline = [ dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=(1333, 800), flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size_divisor=32), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']), ]) ] data = dict( samples_per_gpu=2, workers_per_gpu=2, train=dict( type=dataset_type, ann_file=data_root + 'annotations/instances_train2017.json', img_prefix=data_root + 'train2017/', pipeline=train_pipeline), val=dict( type=dataset_type, ann_file=data_root + 'annotations/instances_val2017.json', img_prefix=data_root + 'val2017/', pipeline=test_pipeline), test=dict( type=dataset_type, ann_file=data_root + 'annotations/instances_val2017.json', img_prefix=data_root + 'val2017/', pipeline=test_pipeline)) evaluation = dict(interval=1, metric='bbox') # optimizer optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) # learning policy lr_config = dict( policy='step', warmup='linear', warmup_iters=500, warmup_ratio=0.001, step=[8, 11]) total_epochs = 12 # checkpoints checkpoint_config = dict(interval=1) log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) # runtime settings dist_params = dict(backend='nccl') log_level = 'INFO' work_dir = './work_dirs/retinanet_r50_fpn_1x' load_from = None resume_from = None workflow = [('train', 1)] ``` 上述配置文件的主要部分包括: 1. `model`:模型设置,包括模型型、预训练模型、骨干网络、neck、bbox_head等。 2. `dataset`:数据集设置,包括数据集型、数据集路径、数据预处理管道等。 3. `optimizer`:优化器设置,包括优化器型、学习率、动量、权重衰减等。 4. `lr_config`:学习率调整策略,包括学习率策略、热身策略、步数和对应学习率等。 5. `total_epochs`:训练总轮数。 6. `checkpoint_config`:保存模型检查点的间隔。 7. `log_config`:日志设置,包括日志输出间隔和日志输出方式等。 8. `dist_params`:分布式参数设置,包括分布式后端等。 9. `work_dir`:训练、测试和推理结果保存路径。 10. `load_from`和`resume_from`:模型加载和恢复路径。 11. `workflow`:训练、测试和推理流程,包括每个阶段的GPU数量。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值