第十章 MMDetection3D解析系列一_数据集(dataset)(车道线感知)

@[TOC](第十章 MMDetection3D解析系列_数据集(dataset)(车道线感知))

一 前言

近期参与到了手写AI的车道线检测的学习中去,以此系列笔记记录学习与思考的全过程。车道线检测系列会持续更新,力求完整精炼,引人启示。所需前期知识,可以结合手写AI进行系统的学习。

二 概述

数据集与数据加载器是MMEngine中训练流程的必要组件,它们的概念来源于 PyTorch数据集,并且在含义上与 PyTorch 保持一致。通常来说,数据集定义了数据的总体数量、读取方式以及预处理,而数据加载器则在不同的设置下迭代地加载数据,如批次大小(batch_size)、随机乱序(shuffle)、并行(num_workers)等。数据集经过数据加载器封装后构成了数据源。我们将按照从外(数据加载器)到内(数据集)的顺序,逐步介绍它们在 MMEngine 执行器中的用法,并给出一些常用示例。你将会:

  1. 掌握如何在 MMEngine 的执行器中配置数据加载器
  2. 学会在配置文件中使用已有(如 torchvision)数据集
  3. 了解如何使用自己的数据集

三 数据加载器详解

在执行器(Runner)中,你可以分别配置以下 3 个参数来指定对应的数据加载器

train_dataloader:在 Runner.train() 中被使用,为模型提供训练数据
val_dataloader:在 Runner.val() 中被使用,也会在 Runner.train() 中每间隔一段时间被使用,用于模型的验证评测
test_dataloader:在 Runner.test() 中被使用,用于模型的测试

MMEngine 完全支持 PyTorch 的原生 DataLoader,因此上述 3 个参数均可以直接传入构建好的 DataLoader。

3.1 示例

示例:以在 CIFAR-10 数据集上训练一个 ResNet-50 模型为例,我们将使用 80 行以内的代码,利用 MMEngine 构建一个完整的、 可配置的训练和验证流程,整个流程包含如下步骤:

构建模型
构建数据集和数据加载器
构建评测指标
构建执行器并执行任务

3.1.1 构建模型

首先,我们需要构建一个模型,在 MMEngine 中,我们约定这个模型应当继承 BaseModel,并且其 forward 方法除了接受来自数据集的若干参数外,还需要接受额外的参数 mode:对于训练,我们需要 mode 接受字符串 “loss”,并返回一个包含 “loss” 字段的字典;对于验证,我们需要 mode 接受字符串 “predict”,并返回同时包含预测信息和真实信息的结果。

import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

   def forward(self, imgs, labels, mode):
       x = self.resnet(imgs)
       if mode == 'loss':
           return {'loss': F.cross_entropy(x, labels)}
       elif mode == 'predict':
           return x, labels

3.1.2 构建数据集和数据加载器

其次,我们需要构建训练和验证所需要的数据集 (Dataset)和数据加载器 (DataLoader)。 对于基础的训练和验证功能,我们可以直接使用符合 PyTorch 标准的数据加载器和数据集。

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))

val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))

3.2.3 构建评测指标

为了进行验证和测试,我们需要定义模型推理结果的评测指标。我们约定这一评测指标需要继承 BaseMetric,并实现 process 和 compute_metrics 方法。其中 process 方法接受数据集的输出和模型 mode=“predict” 时的输出,此时的数据为一个批次的数据,对这一批次的数据进行处理后,保存信息至 self.results 属性。 而 compute_metrics 接受 results 参数,这一参数的输入为 process 中保存的所有信息 (如果是分布式环境,results 中为已收集的,包括各个进程 process 保存信息的结果),利用这些信息计算并返回保存有评测指标结果的字典。

from mmengine.evaluator import BaseMetric

class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # 将一个批次的中间结果保存至 `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # 返回保存有评测指标结果的字典,其中键为指标名称
        return dict(accuracy=100 * total_correct / total_size)

3.2.4 构建执行器并执行任务

最后,我们利用构建好的模型,数据加载器,评测指标构建一个**执行器 (Runner),**同时在其中配置优化器、工作路径、训练与验证配置等选项,即可通过调用 train() 接口启动训练:

from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    # 用以训练和验证的模型,需要满足特定的接口需求
    model=MMResNet50(),
    # 工作路径,用以保存训练日志、权重文件信息
    work_dir='./work_dir',
    # 训练数据加载器,需要满足 PyTorch 数据加载器协议
    train_dataloader=train_dataloader,
    # 优化器包装,用于模型优化,并提供 AMP、梯度累积等附加功能
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # 训练配置,用于指定训练周期、验证间隔等信息
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    # 验证数据加载器,需要满足 PyTorch 数据加载器协议
    val_dataloader=val_dataloader,
    # 验证配置,用于指定验证所需要的额外参数
    val_cfg=dict(),
    # 用于验证的评测器,这里使用默认评测器,并评测指标
    val_evaluator=dict(type=Accuracy),
)

runner.train()

3.2.5 汇总代码

import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD
from torch.utils.data import DataLoader

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner


class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels


class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return dict(accuracy=100 * total_correct / total_size)


norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))

val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))

runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
runner.train()
runner = Runner(
    train_dataloader=dict(
        batch_size=32,
        sampler=dict(
            type='DefaultSampler',
            shuffle=True),
        dataset=torchvision.datasets.CIFAR10(...),
        collate_fn=dict(type='default_collate')
    )
)

在这种情况下,数据加载器会在实际被用到时,在执行器内部被构建。
更多DataLoader build_dataloader

3.2 sampler 与 shuffle

对比Runner可知,本以为将 DataLoader 简单替换为 dict 就可以无缝切换,但遗憾的是,基于注册机制构建时 MMEngine会有一些隐式的转换和约定。我们将介绍其中的不同点。

我们添加了 sampler 参数,这是由于在 MMEngine 中我们要求通过 dict 传入的数据加载器的配置必须包含 sampler 参数。同时,shuffle 参数也从 DataLoader 中移除,这是由于在 PyTorch 中 sampler 与 shuffle 参数是互斥的。事实上,在 PyTorch 的实现中,shuffle 只是一个便利记号。当设置为 True 时 DataLoader 会自动在内部使用 RandomSampler

from mmengine.dataset import DefaultSampler

dataset = torchvision.datasets.CIFAR10(...)
sampler = DefaultSampler(dataset, shuffle=True)

runner = Runner(
    train_dataloader=DataLoader(
        batch_size=32,
        sampler=sampler,
        dataset=dataset,
        collate_fn=default_collate
    )
)
# 对比
runner = Runner(
    train_dataloader=dict(
        batch_size=32,
        sampler=dict(
            type='DefaultSampler',
            shuffle=True),
        dataset=torchvision.datasets.CIFAR10(...),
        collate_fn=dict(type='default_collate')
    )
)

上述代码的等价性只有在:1)使用单进程训练,以及 2)没有配置执行器的 randomness 参数时成立。这是由于使用 dict 传入 sampler 时,执行器会保证它在分布式训练环境设置完成后才被惰性构造,并接收到正确的随机种子。这两点在手动构造时需要额外工作且极易出错。因此,上述的写法只是一个示意而非推荐写法。我们强烈建议 sampler 以 dict 的形式传入,让执行器处理构造顺序,以避免出现问题。

3.2.1 DefaultSampler

上面例子可能会让你好奇:DefaultSampler 是什么,为什么要使用它,是否有其他选项?事实上,DefaultSampler 是 MMEngine 内置的一种采样器,它屏蔽了单进程训练与多进程训练的细节差异,使得单卡与多卡训练可以无缝切换。但在 MMEngine 中,这一细节通过 DefaultSampler 而被屏蔽。

除了 Dataset 本身之外,DefaultSampler 还支持以下参数配置:

shuffle 设置为 True 时会打乱数据集的读取顺序
seed 打乱数据集所用的随机种子,通常不需要在此手动设置,会从 Runner 的 randomness 入参中读取
round_up 设置为 True 时,与 PyTorch DataLoader 中设置 drop_last=False 行为一致。如果你在迁移 PyTorch 的项目,你可能需要注意这一点。

如果你想要使用基于迭代次数 (iteration-based) 的训练流程,你也许会对 InfiniteSampler感兴趣

3.2.2 自定义采样

你可能会想要参考上述两个内置 sampler 的代码,实现一个自定义的 sampler 并注册到 DATA_SAMPLERS 根注册器中。

@DATA_SAMPLERS.register_module()
class MySampler(Sampler):
    pass

runner = Runner(
    train_dataloader=dict(
        sampler=dict(type='MySampler'),
        ...
    )
)

3.3 collate_fn

MMengine 中提供了 2 种内置的 collate_fn:

pseudo_collate,缺省时的默认参数。它不会将数据沿着 batch 的维度合并。详细说明可以参考 pseudo_collate

default_collate,与 PyTorch 中的 default_collate 行为几乎完全一致,会将数据转化为 Tensor 并沿着 batch 维度合并。一些细微不同和详细说明可以参考 default_collate

如果你想要使用自定义的 collate_fn,你也可以将它注册到 FUNCTIONS 根注册器中来使用

@FUNCTIONS.register_module()
def my_collate_func(data_batch: Sequence) -> Any:
    pass

runner = Runner(
    train_dataloader=dict(
        ...
        collate_fn=dict(type='my_collate_func')
    )
)

3.4 torchvision数据集

行注册和构建

import torchvision.transforms as tvt
from mmengine.registry import DATASETS, TRANSFORMS
from mmengine.dataset.base_dataset import Compose

# 注册 torchvision 的 CIFAR10 数据集
# 数据预处理也需要在此一起构建
@DATASETS.register_module(name='Cifar10', force=False)
def build_torchvision_cifar10(transform=None, **kwargs):
    if isinstance(transform, dict):
        transform = [transform]
    if isinstance(transform, (list, tuple)):
        transform = Compose(transform)
    return torchvision.datasets.CIFAR10(**kwargs, transform=transform)

# 注册 torchvision 中用到的数据预处理模块
DATA_TRANSFORMS.register_module('RandomCrop', module=tvt.RandomCrop)
DATA_TRANSFORMS.register_module('RandomHorizontalFlip', module=tvt.RandomHorizontalFlip)
DATA_TRANSFORMS.register_module('ToTensor', module=tvt.ToTensor)
DATA_TRANSFORMS.register_module('Normalize', module=tvt.Normalize)

# 在 Runner 中使用
runner = Runner(
    train_dataloader=dict(
        batch_size=32,
        sampler=dict(
            type='DefaultSampler',
            shuffle=True),
        dataset=dict(type='Cifar10',
            root='data/cifar10',
            train=True,
            download=True,
            transform=[
                dict(type='RandomCrop', size=32, padding=4),
                dict(type='RandomHorizontalFlip'),
                dict(type='ToTensor'),
                dict(type='Normalize', **norm_cfg)])
    )
)

3.5 创建自定义数据集

自定义数据集类必须实现三个函数:initlen__和__getitem。 看看这个实现;存储了时尚MNIST图像 在目录中,并且它们的标签单独存储在CSV文件中。img_dirannotations_file

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

四 MMEngine 的数据集基类

4.1 基本介绍

因此 MMEngine 实现了一个数据集基类(BaseDataset)并定义了一些基本接口,且基于这套接口实现了一些数据集包装(DatasetWrapper)。OpenMMLab 算法库中的大部分数据集都会满足这套数据集基类定义的接口,并使用统一的数据集包装。

数据集基类的基本功能是加载数据集信息,这里我们将数据集信息分成两类,一种是元信息 (meta information),代表数据集自身相关的信息,有时需要被模型或其他外部组件获取,比如在图像分类任务中,数据集的元信息一般包含类别信息 classes,因为分类模型 model 一般需要记录数据集的类别信息;另一种为数据信息 (data information),在数据信息中,定义了具体样本的文件路径、对应标签等的信息。除此之外,数据集基类的另一个功能为不断地将数据送入数据流水线(data pipeline)中,进行数据预处理。

4.2 数据标注文件规范

OpenMMLab 2.0 数据集格式规范规定,标注文件必须为 json 或 yaml,yml 或 pickle,pkl 格式;标注文件中存储的字典必须包含 metainfo 和 data_list 两个字段。其中 metainfo 是一个字典,里面包含数据集的元信息;data_list 是一个列表,列表中每个元素是一个字典,该字典定义了一个原始数据(raw data),每个原始数据包含一个或若干个训练/测试样本。

以下是一个 JSON 标注文件的例子(该例子中每个原始数据只包含一个训练/测试样本):

{
    "metainfo":
        {
            "classes": ["cat", "dog"]
        },
    "data_list":
        [
            {
                "img_path": "xxx/xxx_0.jpg",
                "img_label": 0
            },
            {
                "img_path": "xxx/xxx_1.jpg",
                "img_label": 1
            }
        ]
}

同时假设数据存放路径如下:

data
├── annotations
│   ├── train.json
├── train
│   ├── xxx/xxx_0.jpg
│   ├── xxx/xxx_1.jpg
│   ├── ...

4.3 数据集基类的初始化流程

初始化流程

load metainfo:获取数据集的元信息,元信息有三种来源,优先级从高到低为:
init() 方法中用户传入的 metainfo 字典;改动频率最高,因为用户可以在实例化数据集时,传入该参数;
类属性 BaseDataset.METAINFO 字典;改动频率中等,因为用户可以改动自定义数据集类中的类属性 BaseDataset.METAINFO;
标注文件中包含的 metainfo 字典;改动频率最低,因为标注文件一般不做改动。
如果三种来源中有相同的字段,优先级最高的来源决定该字段的值,这些字段的优先级比较是:用户传入的 metainfo 字典里的字段 > BaseDataset.METAINFO 字典里的字段 > 标注文件中 metainfo 字典里的字段。
join path:处理数据与标注文件的路径;
build pipeline:构建数据流水线(data pipeline),用于数据预处理与数据准备;
full init:完全初始化数据集类,该步骤主要包含以下操作:
load data list:读取与解析满足 OpenMMLab 2.0 数据集格式规范的标注文件,该步骤中会调用 parse_data_info() 方法,该方法负责解析标注文件里的每个原始数据;
filter data (可选):根据 filter_cfg 过滤无用数据,比如不包含标注的样本等;默认不做过滤操作,下游子类可以按自身所需对其进行重写;
get subset (可选):根据给定的索引或整数值采样数据,比如只取前 10 个样本参与训练/测试;默认不采样数据,即使用全部数据样本;
serialize data (可选):序列化全部样本,以达到节省内存的效果,详情请参考节省内存;默认操作为序列化全部样本。

4.4 数据集基类提供的接口

与 torch.utils.data.Dataset 类似,数据集初始化后,支持 getitem 方法,用来索引数据,以及 len 操作获取数据集大小,除此之外,OpenMMLab 的数据集基类主要提供了以下接口来访问具体信息:

metainfo:返回元信息,返回值为字典
get_data_info(idx):返回指定 idx 的样本全量信息,返回值为字典
getitem(idx):返回指定 idx 的样本经过 pipeline 之后的结果(也就是送入模型的数据),返回值为字典
len():返回数据集长度,返回值为整数型
get_subset_(indices):根据 indices 以 inplace 的方式修改原数据集类。如果 indices 为 int,则原数据集类只包含前若干个数据样本;如果 indices 为 Sequence[int],则原数据集类包含根据 Sequence[int] 指定的数据样本。
get_subset(indices):根据 indices 以非 inplace 的方式返回子数据集类,即重新复制一份子数据集。如果 indices 为 int,则返回的子数据集类只包含前若干个数据样本;如果 indices 为 Sequence[int],则返回的子数据集类包含根据 Sequence[int] 指定的数据样本。

五 使用数据集基类自定义数据集类

import os.path as osp

from mmengine.dataset import BaseDataset


class ToyDataset(BaseDataset):

    # 以上面标注文件为例,在这里 raw_data_info 代表 `data_list` 对应列表里的某个字典:
    # {
    #    'img_path': "xxx/xxx_0.jpg",
    #    'img_label': 0,
    #    ...
    # }
    def parse_data_info(self, raw_data_info):
        data_info = raw_data_info
        img_prefix = self.data_prefix.get('img_path', None)
        if img_prefix is not None:
            data_info['img_path'] = osp.join(
                img_prefix, data_info['img_path'])
        return data_info

class LoadImage:

    def __call__(self, results):
        results['img'] = cv2.imread(results['img_path'])
        return results

class ParseImage:

    def __call__(self, results):
        results['img_shape'] = results['img'].shape
        return results

pipeline = [
    LoadImage(),
    ParseImage(),
]
# 在定义了数据集类后,就可以通过如下配置实例化 ToyDataset:
toy_dataset = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='train/'),
    ann_file='annotations/train.json',
    pipeline=pipeline)
    data_root: 数据的根目录,指定数据存储的主目录。
'''data_prefix: 数据路径的前缀,用于构建完整的数据路径。在这个示例中,通过dict(img_path='train/')将img_path的前缀设置为'train/
ann_file: 注释文件的路径,指定包含数据文件标注信息的文件。
pipeline: 数据处理的流程,定义了一系列数据处理类的顺序。在这个示例中,pipeline列表中包含了LoadImage()和ParseImage()两个数据处理类的对象。'''

同时可以使用数据集类提供的对外接口访问具体的样本信息:

toy_dataset.metainfo
# dict(classes=('cat', 'dog'))

toy_dataset.get_data_info(0)
# {
#     'img_path': "data/train/xxx/xxx_0.jpg",
#     'img_label': 0,
#     ...
# }

len(toy_dataset)
# 2

toy_dataset[0]
# {
#     'img_path': "data/train/xxx/xxx_0.jpg",
#     'img_label': 0,
#     'img': a ndarray with shape (H, W, 3), which denotes the value of the image,
#     'img_shape': (H, W, 3) ,
#     ...
# }

# `get_subset` 接口不对原数据集类做修改,即完全复制一份新的
sub_toy_dataset = toy_dataset.get_subset(1)
len(toy_dataset), len(sub_toy_dataset)
# 2, 1

# `get_subset_` 接口会对原数据集类做修改,即 inplace 的方式
toy_dataset.get_subset_(1)
len(toy_dataset)
# 1

六 自定义视频的数据集类

在上面的例子中,标注文件的每个原始数据只包含一个训练/测试样本(通常是图像领域)。如果每个原始数据包含若干个训练/测试样本(通常是视频领域),则只需保证 parse_data_info() 的返回值为 list[dict] 即可:

from mmengine.dataset import BaseDataset


class ToyVideoDataset(BaseDataset):

    # raw_data_info 仍为一个字典,但它包含了多个样本
    def parse_data_info(self, raw_data_info):
        data_list = []

        ...

        for ... :

            data_info = dict()

            ...

            data_list.append(data_info)

        return data_list

对于不满足 OpenMMLab 2.0 数据集格式规范的标注文件

将不满足规范的标注文件转换成满足规范的标注文件,再通过上述方式使用数据集基类。
实现一个新的数据集类,继承自数据集基类,并且重载数据集基类的 load_data_list(self): 函数,处理不满足规范的标注文件,并保证返回值为 list[dict],其中每个 dict 代表一个数据样本。

七 数据集基类的其它特性

在数据集类实例化时,需要读取并解析标注文件,因此会消耗一定时间。然而在某些情况比如预测可视化时,往往只需要数据集类的元信息,可能并不需要读取与解析标注文件。为了节省这种情况下数据集类实例化的时间,数据集基类支持懒加载(lazy init):

pipeline = [
    LoadImage(),
    ParseImage(),
]

toy_dataset = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='train/'),
    ann_file='annotations/train.json',
    pipeline=pipeline,
    # 在这里传入 lazy_init 变量
    lazy_init=True)

在数据集类实例化时,需要读取并解析标注文件,因此会消耗一定时间。然而在某些情况比如预测可视化时,往往只需要数据集类的元信息,可能并不需要读取与解析标注文件。为了节省这种情况下数据集类实例化的时间,数据集基类支持懒加载:

pipeline = [
    LoadImage(),
    ParseImage(),
]

toy_dataset = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='train/'),
    ann_file='annotations/train.json',
    pipeline=pipeline,
    # 在这里传入 lazy_init 变量
    lazy_init=True)

当 lazy_init=True 时,ToyDataset 的初始化方法只执行了数据集基类的初始化流程中的 1、2、3 步骤,此时 toy_dataset 并未被完全初始化,因为 toy_dataset 并不会读取与解析标注文件,只会设置数据集类的元信息(metainfo)。

自然的,如果之后需要访问具体的数据信息,可以手动调用 toy_dataset.full_init() 接口来执行完整的初始化过程,在这个过程中数据标注文件将被读取与解析。调用 get_data_info(idx), len(), getitem(idx),get_subset_(indices), get_subset(indices) 接口也会自动地调用 full_init() 接口来执行完整的初始化过程(仅在第一次调用时,之后调用不会重复地调用 full_init() 接口):

# 完整初始化
toy_dataset.full_init()

# 初始化完毕,现在可以访问具体数据
len(toy_dataset)
# 2
toy_dataset[0]
# {
#     'img_path': "data/train/xxx/xxx_0.jpg",
#     'img_label': 0,
#     'img': a ndarray with shape (H, W, 3), which denotes the value the image,
#     'img_shape': (H, W, 3) ,
#     ...
# }

在具体的读取数据过程中,数据加载器(dataloader)通常会起多个 worker 来预取数据,多个 worker 都拥有完整的数据集类备份,因此内存中会存在多份相同的 data_list,为了节省这部分内存消耗,数据集基类可以提前将 data_list 序列化存入内存中,使得多个 worker 可以共享同一份 data_list,以达到节省内存的目的。

数据集基类默认是将 data_list 序列化存入内存,也可以通过 serialize_data 变量(默认为 True)来控制是否提前将 data_list 序列化存入内存中:

pipeline = [
    LoadImage(),
    ParseImage(),
]

toy_dataset = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='train/'),
    ann_file='annotations/train.json',
    pipeline=pipeline,
    # 在这里传入 serialize_data 变量
    serialize_data=False)

上面例子不会提前将 data_list 序列化存入内存中,因此不建议在使用数据加载器开多个 worker 加载数据的情况下,使用这种方式实例化数据集类。

八 数据集基类包装

除了数据集基类,MMEngine 也提供了若干个数据集基类包装:ConcatDataset, RepeatDataset, ClassBalancedDataset。这些数据集基类包装同样也支持懒加载与拥有节省内存的特性。

8.1 ConcatDataset

MMEngine 提供了 ConcatDataset 包装来拼接多个数据集,使用方法如下:

from mmengine.dataset import ConcatDataset

pipeline = [
    LoadImage(),
    ParseImage(),
]

toy_dataset_1 = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='train/'),
    ann_file='annotations/train.json',
    pipeline=pipeline)

toy_dataset_2 = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='val/'),
    ann_file='annotations/val.json',
    pipeline=pipeline)

toy_dataset_12 = ConcatDataset(datasets=[toy_dataset_1, toy_dataset_2])

上述例子将数据集的 train 部分与 val 部分合成一个大的数据集。

8.2 RepeatDataset

MMEngine 提供了 RepeatDataset 包装来重复采样某个数据集若干次,使用方法如下:

from mmengine.dataset import RepeatDataset

pipeline = [
    LoadImage(),
    ParseImage(),
]

toy_dataset = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='train/'),
    ann_file='annotations/train.json',
    pipeline=pipeline)

toy_dataset_repeat = RepeatDataset(dataset=toy_dataset, times=5)

8.3 ClassBalancedDataset

MMEngine 提供了 ClassBalancedDataset 包装,来基于数据集中类别出现频率,重复采样相应样本。

注意:
ClassBalancedDataset 包装假设了被包装的数据集类支持 get_cat_ids(idx) 方法,get_cat_ids(idx) 方法返回一个列表,该列表包含了 idx 指定的 data_info 包含的样本类别,使用方法如下:

from mmengine.dataset import BaseDataset, ClassBalancedDataset

class ToyDataset(BaseDataset):

    def parse_data_info(self, raw_data_info):
        data_info = raw_data_info
        img_prefix = self.data_prefix.get('img_path', None)
        if img_prefix is not None:
            data_info['img_path'] = osp.join(
                img_prefix, data_info['img_path'])
        return data_info

    # 必须支持的方法,需要返回样本的类别
    def get_cat_ids(self, idx):
        data_info = self.get_data_info(idx)
        return [int(data_info['img_label'])]

pipeline = [
    LoadImage(),
    ParseImage(),
]

toy_dataset = ToyDataset(
    data_root='data/',
    data_prefix=dict(img_path='train/'),
    ann_file='annotations/train.json',
    pipeline=pipeline)

toy_dataset_repeat = ClassBalancedDataset(dataset=toy_dataset, oversample_thr=1e-3)

上述例子将数据集的 train 部分以 oversample_thr=1e-3 重新采样,具体地,对于数据集中出现频率低于 1e-3 的类别,会重复采样该类别对应的样本,否则不重复采样,具体采样策略请参考 ClassBalancedDataset API 文档。

九 自定义数据集类包装

由于数据集基类实现了懒加载的功能,因此在自定义数据集类包装时,需要遵循一些规则,下面以一个例子的方式来展示如何自定义数据集类包装:

from mmengine.dataset import BaseDataset
from mmengine.registry import DATASETS


@DATASETS.register_module()
class ExampleDatasetWrapper:

    def __init__(self, dataset, lazy_init=False, ...):
        # 构建原数据集(self.dataset)
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'elements in datasets sequence should be config or '
                f'`BaseDataset` instance, but got {type(dataset)}')
        # 记录原数据集的元信息
        self._metainfo = self.dataset.metainfo

        '''
        1. 在这里实现一些代码,来记录用于包装数据集的一些超参。
        '''

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    def full_init(self):
        if self._fully_initialized:
            return

        # 将原数据集完全初始化
        self.dataset.full_init()

        '''
        2. 在这里实现一些代码,来包装原数据集。
        '''

        self._fully_initialized = True

    @force_full_init
    def _get_ori_dataset_idx(self, idx: int):

        '''
        3. 在这里实现一些代码,来将包装的索引 `idx` 映射到原数据集的索引 `ori_idx`。
        '''
        ori_idx = ...

        return ori_idx

    # 提供与 `self.dataset` 一样的对外接口。
    @force_full_init
    def get_data_info(self, idx):
        sample_idx = self._get_ori_dataset_idx(idx)
        return self.dataset.get_data_info(sample_idx)

    # 提供与 `self.dataset` 一样的对外接口。
    def __getitem__(self, idx):
        if not self._fully_initialized:
            warnings.warn('Please call `full_init` method manually to '
                          'accelerate the speed.')
            self.full_init()

        sample_idx = self._get_ori_dataset_idx(idx)
        return self.dataset[sample_idx]

    # 提供与 `self.dataset` 一样的对外接口。
    @force_full_init
    def __len__(self):

        '''
        4. 在这里实现一些代码,来计算包装数据集之后的长度。
        '''
        len_wrapper = ...

        return len_wrapper

    # 提供与 `self.dataset` 一样的对外接口。
    @property
    def metainfo(self)
        return copy.deepcopy(self._metainfo)
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
你可以按照以下步骤将KITTI数据集转换为MMDetection3D格式: 1. 首先,确保你已经下载了KITTI数据集,并且文件结构如下所示: ``` KITTI ├── training │ ├── calib │ ├── image_2 │ ├── label_2 │ ├── velodyne │ └── ... └── testing ├── calib ├── image_2 ├── velodyne └── ... ``` 2. 然后,你需要安装MMDetection3D库。你可以按照官方文档中的指示进行安装:https://mmdetection3d.readthedocs.io/en/latest/getting_started.html#installation 3. 接下来,你需要创建一个配置文件,指定数据集的相关信息。在MMDetection3D中,配置文件通常是一个Python脚本。你可以在`configs/dataset`目录下找到示例配置文件,比如`kitti_dataset.py`。 4. 打开配置文件,并根据你的数据集路径进行相应的修改。主要需要修改的变量有: - `root_path`:指定KITTI数据集的根路径。 - `train_pipeline`和`test_pipeline`:指定数据预处理和增强的操作。 5. 保存并关闭配置文件。 6. 现在,你可以使用MMDetection3D提供的工具将KITTI数据集转换为MMDetection3D格式。在命令行中执行以下命令: ``` python tools/data_converter/kitti_converter.py <path_to_config_file> ``` 其中,`<path_to_config_file>`是你刚刚创建的配置文件的路径。 7. 执行上述命令后,MMDetection3D将会将KITTI数据集转换为MMDetection3D格式,并保存在指定的输出路径中。 完成上述步骤后,你就成功将KITTI数据集转换为MMDetection3D格式了。你可以使用转换后的数据集进行目标检测和3D物体检测任务。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小酒馆燃着灯

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值