【mmdetection】使用 dataloader 加载 COCO 数据集
时间:2022年9月10日
平时调试 mmdetection 代码时需要载入部分数据，因此写了一个脚本，方便加载数据。
参考:
datasets的构建参考./tools/train.py
data_loaders的构建参考./mmdet/apis/train.py
注意:
datasets 载入的数据已经过训练 pipeline 的预处理与数据增强，即已完成缩放、翻转、归一化（Normalize）和填充（Pad）。
代码如下:
from mmdet.datasets import build_dataset, build_dataloader
from mmcv import Config
from mmdet.apis import init_random_seed, set_random_seed
import torch.distributed as dist
import argparse
def parse_args(argv=None):
    """Parse the command-line options of the data-loading debug script.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, which preserves the original
            command-line behaviour. Pass an explicit list (e.g. ``[]``) when
            calling from a notebook or test runner whose ``sys.argv`` contains
            unrelated options that argparse would otherwise reject.

    Returns:
        argparse.Namespace with the attributes ``config``, ``seed``,
        ``diff_seed``, ``deterministic``, ``gpus``, ``gpu_ids`` and ``gpu_id``.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument(
        '--config',
        default='./configs/deformable_detr/deformable_detr_r50_16x2_50e_coco.py',
        help='train config file path')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--diff-seed',
        action='store_true',
        help='Whether or not set different seeds for different ranks')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')

    # The three GPU-selection flags are mutually exclusive: argparse exits
    # with an error if more than one of them is given explicitly.
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='(Deprecated, please use --gpu-id) number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed training)')

    # argv=None -> argparse falls back to sys.argv[1:].
    return parser.parse_args(argv)
def _get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    # dist.get_rank() raises when the default process group has not been
    # initialised, which is the normal state for this single-process script.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


def get_data(samples_per_gpu=2, workers_per_gpu=2):
    """Build the training dataset and dataloader from an mmdet config; return the first batch.

    Options (config path, seed, GPU selection) are read from the command line
    via ``parse_args()``. The returned samples have already been through the
    transforms of ``cfg.data.train``'s pipeline.

    Args:
        samples_per_gpu: Samples per GPU in each batch (default 2, the value
            previously hard-coded).
        workers_per_gpu: DataLoader worker processes per GPU (default 2, the
            value previously hard-coded).

    Returns:
        tuple ``(img_metas, img, gt_bboxes, gt_labels)`` taken from the first
        batch, restricted to the first GPU slot (``.data[0]``).

    Raises:
        RuntimeError: If the dataloader yields no batches at all.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # Seed before building anything, following the order used in mmdet's
    # tools/train.py, so the seed is already in effect for everything below.
    seed = init_random_seed(args.seed)
    if args.diff_seed:
        # Per-rank offset; _get_rank() yields 0 when torch.distributed is not
        # initialised instead of raising as dist.get_rank() would.
        seed += _get_rank()
    set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed

    dataset = build_dataset(cfg.data.train)
    print(f'datasets build finish! total : {len(dataset)}')

    # Only the number of GPUs matters for the non-distributed loader below
    # (num_gpus=len(cfg.gpu_ids)); --gpu-id is therefore not consulted here.
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    runner_type = cfg.runner['type'] if 'runner' in cfg else 'EpochBasedRunner'

    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=workers_per_gpu,
        # `num_gpus` will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=False,
        seed=cfg.seed,
        runner_type=runner_type,
        persistent_workers=cfg.data.get('persistent_workers', False))
    print(f'data_loaders build finish! total : {len(data_loader)}')

    # Take exactly one batch. The previous `for ... break` left the return
    # values unbound (UnboundLocalError) when the loader was empty.
    try:
        data_batch = next(iter(data_loader))
    except StopIteration:
        raise RuntimeError(
            'The training dataloader yielded no batches; '
            'check cfg.data.train and its annotation file.') from None

    # data_batch keys: 'img_metas', 'img', 'gt_bboxes', 'gt_labels'.
    # Each field's `.data` is indexed with [0]; presumably mmcv's collate
    # stores one entry per GPU slot, so [0] is the first GPU's chunk
    # (verify against mmcv DataContainer if num_gpus > 1).
    img_metas_batch = data_batch['img_metas'].data[0]  # list, one dict per sample
    img_batch = data_batch['img'].data[0]  # e.g. [2, 3, 736, 758] after padding
    # ([n1, 4], [n2, 4]): mmdet's middle format stores boxes as
    # (x1, y1, x2, y2), not COCO's (x, y, w, h).
    gt_bboxes_batch = data_batch['gt_bboxes'].data[0]
    gt_labels_batch = data_batch['gt_labels'].data[0]  # ([n1], [n2])
    return img_metas_batch, img_batch, gt_bboxes_batch, gt_labels_batch