1、数据管道设计
遵循典型的约定,我们使用Dataset并DataLoader用于多个工作单元的数据加载。Dataset返回与模型正向方法的参数相对应的数据项的字典。
数据准备管道和数据集是分开的。通常,数据集定义如何处理标记数据,数据管道定义所有准备数据字典的步骤。管道由一系列操作组成,每个操作都将一个dict作为输入,并为下一个转换输出一个dict。这些操作分为数据加载,预处理和格式化。
这是在ImageNet上进行ResNet-50训练的管道示例。
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', size=224),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=256),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
由于一些错误,LoadImageFromFile会从磁盘加载图像,但这可能会导致有效小型模型的IO瓶颈。mmcv支持各种后端以加速此过程。例如,如果训练机已设置 memcached,则可以如下修改配置。
memcached_root = '/mnt/xxx/memcached_client/'
train_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args=dict(
backend='memcached',
server_list_cfg=osp.join(memcached_root, 'server_list.conf'),
client_cfg=osp.join(memcached_root, 'client.conf'))),
]
在mmcv.fileio.FileClient中可以找到更多受支持的后端。
对于每个操作,我们列出了添加/更新/删除的相关字典字段。在管道的最后,我们Collect只保留用于正向计算的必要项目。
1.1、数据载入
LoadImageFromFile
添加:img,img_shape,ori_shape
1.2、预处理
Resize
添加:scale,scale_idx,pad_shape,scale_factor,keep_ratio
更新:img,img_shape
RandomFlip
添加:翻转,flip_direction
更新:img
RandomCrop
更新:img,pad_shape
Normalize
添加:img_norm_cfg
更新:img
1.3、格式化
ToTensor
更新:由指定keys。
ImageToTensor
更新:由指定keys。
Transpose
更新:由指定keys。
Collect
删除:除由所指定的键以外的所有其他键 keys
2、扩展和使用自定义管道
step1、在任何文件中编写新管道,如my_pipeline.py。它以字典作为输入并返回一个字典。
from mmcls.datasets import PIPELINES
@PIPELINES.register_module()
class MyTransform(object):
def __call__(self, results):
results['dummy'] = True
# apply transforms on results['img']
return results
step2、导入新类。
from .my_pipeline import MyTransform
step3、在配置文件中使用它。
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', size=224),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='MyTransform'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
传送门:mmclassification项目阅读系列文章目录
教程文档翻译: