mmclassification-自定义数据管道(四)

1、数据管道设计

遵循典型的约定,我们使用Dataset并DataLoader用于多个工作单元的数据加载。Dataset返回与模型正向方法的参数相对应的数据项的字典。

数据准备管道和数据集是分开的。通常,数据集定义如何处理标记数据,数据管道定义所有准备数据字典的步骤。管道由一系列操作组成,每个操作都将一个dict作为输入,并为下一个转换输出一个dict。这些操作分为数据加载,预处理和格式化。

这是在ImageNet上进行ResNet-50训练的管道示例。

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
    
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=256),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

由于一些错误,LoadImageFromFile会从磁盘加载图像,但这可能会导致有效小型模型的IO瓶颈。mmcv支持各种后端以加速此过程。例如,如果训练机已设置 memcached,则可以如下修改配置。

memcached_root = '/mnt/xxx/memcached_client/'
train_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args=dict(
            backend='memcached',
            server_list_cfg=osp.join(memcached_root, 'server_list.conf'),
            client_cfg=osp.join(memcached_root, 'client.conf'))),
]

mmcv.fileio.FileClient中可以找到更多受支持的后端。

对于每个操作,我们列出了添加/更新/删除的相关字典字段。在管道的最后,我们Collect只保留用于正向计算的必要项目。

1.1、数据载入

LoadImageFromFile

添加:img,img_shape,ori_shape

1.2、预处理

Resize

添加:scale,scale_idx,pad_shape,scale_factor,keep_ratio

更新:img,img_shape

RandomFlip

添加:翻转,flip_direction

更新:img

RandomCrop

更新:img,pad_shape

Normalize

添加:img_norm_cfg

更新:img

1.3、格式化

ToTensor

更新:由指定keys。

ImageToTensor

更新:由指定keys。

Transpose

更新:由指定keys。

Collect

删除:除由所指定的键以外的所有其他键 keys

2、扩展和使用自定义管道

step1、在任何文件中编写新管道,如my_pipeline.py。它以字典作为输入并返回一个字典。

from mmcls.datasets import PIPELINES

@PIPELINES.register_module()
class MyTransform(object):

    def __call__(self, results):
        results['dummy'] = True
        # apply transforms on results['img']
        return results

step2、导入新类。

from .my_pipeline import MyTransform

step3、在配置文件中使用它。

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='MyTransform'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]

传送门:mmclassification项目阅读系列文章目录

教程文档翻译:

mmclassification-安装使用(一)

mmclassification-模型微调(二)

mmclassification-添加新数据集(三)

mmclassification-自定义数据管道(四)

mmclassification-添加新模块(五)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值