【mmdetection】中dataloader加载COCO数据集

【mmdetection】中dataloader加载COCO数据集
时间:2022年9月10日

平常调试mmdetection代码,需要载入部分数据,所以写个脚本,方便数据加载。

参考:
datasets的构建参考./tools/train.py
data_loaders的构建参考./mmdet/apis/train.py

注意:
datasets载入的数据就已经是数据增强后的了,已经是经过缩放、翻转、正则化、填充后的了

代码如下:

from mmdet.datasets import build_dataset, build_dataloader
from mmcv import Config
from mmdet.apis import init_random_seed, set_random_seed
import torch.distributed as dist
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('--config', default='./configs/deformable_detr/deformable_detr_r50_16x2_50e_coco.py', help='train config file path')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--diff-seed',
        action='store_true',
        help='Whether or not set different seeds for different ranks')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='(Deprecated, please use --gpu-id) number of gpus to use '
             '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
             '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
             '(only applicable to non-distributed training)')
    args = parser.parse_args()
    return args

def get_data():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    datasets = build_dataset(cfg.data.train)
    print(f'datasets build finish! total : {len(datasets)}')

    # set random seeds
    seed = init_random_seed(args.seed)
    seed = seed + dist.get_rank() if args.diff_seed else seed
    set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed

    # set gpu_ids
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
    # set runner_type
    runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
        'type']

    data_loaders = build_dataloader(
        datasets,
        samples_per_gpu=2,
        workers_per_gpu=2,
        # `num_gpus` will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=False,
        seed=cfg.seed,
        runner_type=runner_type,
        persistent_workers=cfg.data.get('persistent_workers', False))

    print(f'data_loaders build finish! total : {len(data_loaders)}')

    for i, data_batch in enumerate(data_loaders):  # data_batch ['img_metas', 'img', 'gt_bboxes', 'gt_labels']
        img_metas_batch = data_batch['img_metas'].data[0]  # len = 2
        img_batch = data_batch['img'].data[0]  # [2, 3, 736, 758]
        gt_bboxes_batch = data_batch['gt_bboxes'].data[0]  # ([n1, 4], [n2, 4]) = ([n1, [x, y, w, h]], [n2, [x, y, w, h]])
        gt_labels_batch = data_batch['gt_labels'].data[0]  # ([n1], [n2])
        break
        
    return img_metas_batch, img_batch, gt_bboxes_batch, gt_labels_batch
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
好的,数据集加载完成后,我们通常会使用dataloader来准备数据,以便于模型训练。dataloader可以帮助我们对数据进行批量处理、乱序、并行读取等操作,提高模型训练效率。 以文本数据集为例,我们可以先将每个文本转换成数字序列,并将所有文本序列对齐到相同长度,然后将所有文本序列打包成batch,供模型训练使用。具体操作可以参考以下代码: ```python import torch.utils.data as data from torch.nn.utils.rnn import pad_sequence import numpy as np class TextDataset(data.Dataset): def __init__(self, data): self.data = data self.tokenizer = Tokenizer() def __getitem__(self, index): # 将文本转换成数字序列 text = self.data[index] tokens = self.tokenizer.tokenize(text) ids = self.tokenizer.convert_tokens_to_ids(tokens) return ids def __len__(self): return len(self.data) def collate_fn(batch): # 将序列对齐到相同长度 lengths = [len(seq) for seq in batch] max_len = max(lengths) padded_seqs = [seq + [0]*(max_len-len(seq)) for seq in batch] padded_seqs = torch.LongTensor(padded_seqs) return padded_seqs, lengths # 加载数据集 data = ['text1', 'text2', 'text3', 'text4', 'text5'] dataset = TextDataset(data) dataloader = data.DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn, num_workers=0) # 使用dataloader训练模型 for batch in dataloader: inputs, lengths = batch outputs = model(inputs) loss = criterion(outputs, targets) optimizer.zero_grad() loss.backward() optimizer.step() ``` 在上面的代码,我们自定义了一个TextDataset类,用于将文本数据转换成数字序列。在collate_fn函数,我们将所有序列对齐到相同长度,并将它们打包成batch。最后使用DataLoader加载数据集,并传入collate_fn函数进行处理。 需要根据具体的数据集格式和模型需求来选择相应的方法。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值