前言
本篇是mmdetection源码解读第二篇,主要讲解mmdetection是初始化数据类的。本文以coco数据集为例,当然,源码解读不可能面面俱到,重要的是揣摩设计者的思想以及实现过程。另外,本文先暂时不予介绍dataloader构建过程。
1、总体流程
通常我们利用pytorch读取数据集需要构建两个部分,一个是数据集初始化,主要完成数据集的存储路径;一个是实现getitem方法,变成迭代器来训练模型:
这里解释下pipline。在mmdetection中,pipline实际上是一系列顺序的关于图像读取,增强,合并的函数。即实例了一个图像增强对象,之后在getitem中利用transforms对data进行增强。这里简单有个理解即可。后续我会详细介绍。
2、实现流程
2.1. coco_detection训练配置文件
截取mmdetection中用于train的训练集的配置文件。代码:configs/_base_/datasets/coco_detection.。
dataset_type = 'CocoDataset'
data_root = '/home/wujian/WLL/mmdet-master/data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(800, 512), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
上述配置文件中,dataset_type表示读取coco格式的数据集。data_root是数据集存储路径。train_pipline用于图像增强函数的参数文件。
2.2. CocoDataset初始化
mmdetection中使用build_dataset函数来完成dataset实例化。
datasets = [build_dataset(cfg.data.train)]
这里内部build_dataset实质上内部调用了build_from_cfg函数(这一块我不介绍了,要不太冗余了,主要理解设计思想),这个函数将cfg文件用于CocoDataset类初始化,而CocoDataset类继承自CustomDataset类,我主要截取重要部分,地址:mmdet/datasets/custom.py。
@DATASETS.register_module()
class CustomDataset(Dataset):
CLASSES = None
def __init__(self,
ann_file,
pipeline,
classes=None,
data_root=None,
img_prefix='',
seg_prefix=None,
proposal_file=None,
test_mode=False,
filter_empty_gt=True):
self.ann_file = ann_file
self.data_root = data_root
self.img_prefix = img_prefix
self.seg_prefix = seg_prefix
self.proposal_file = proposal_file
self.test_mode = test_mode
self.filter_empty_gt = filter_empty_gt
self.CLASSES = self.get_classes(classes)
# load annotations (and proposals)
self.data_infos = self.load_annotations(self.ann_file)
# processing pipeline
self.pipeline = Compose(pipeline)
这里初始化了data_root,值得注意的是最后一行self.pipline = Compose(pipline),这就是第一部分实例化了一个图像增强的类。
我们看下Compose类:
@PIPELINES.register_module()
class Compose(object):
def __init__(self, transforms):
assert isinstance(transforms, collections.abc.Sequence)
self.transforms = [] # transforms即传入的一个长度为8,且每个元素是字典的list。[{'type':'LoadImageFromFile'}]
for transform in transforms:
if isinstance(transform, dict):
transform = build_from_cfg(transform, PIPELINES)
self.transforms.append(transform)
参数transfoms是个长度为8的list,各个元素是字典,字典的内容就是train_pipline中内容。举个例子:
trainsfoms=[{‘type’:LoadImageFromFile},{‘type’:LoadAnnotations}。在Compose初始化中,通过遍历transforms里面的8个元素,利用build_from_cfg函数完成了各个类的实例化,之后将各个实例对象append进self.transforms列表中。至此,Compose类实际上里面存储的是顺序的图像增强实例对象。至此,CocoDataset初始化部分完成。
2.3. CocoDataset中getitem实现
放下getitem函数,依旧在CustomDataset类内:
def __getitem__(self, idx):
if self.test_mode:
return self.prepare_test_img(idx)
while True:
data = self.prepare_train_img(idx)
if data is None:
idx = self._rand_another(idx) # 这里写的鲁棒,若idx失效,则随机读取另一张图像
continue
return data
def prepare_train_img(self, idx):
img_info = self.data_infos[idx]
ann_info = self.get_ann_info(idx)
results = dict(img_info=img_info, ann_info=ann_info)
if self.proposals is not None:
results['proposals'] = self.proposals[idx]
self.pre_pipeline(results)
return self.pipeline(results)
从函数可以看出:首先借助idx读取data,然后利用prepare_train_img完成data的图像增强。之后,return data。
总结
以上就是mmdetection中dataset类实例过程。首先初始化路径以及完成图像增强pipline的实例。然后完成getitem函数。