1. 前言
这篇博客讨论mmDetection
框架的数据输入过程。主要讨论SA-SSD
的数据流。
2. 数据输入
2.1 DataLoader类
Pytorch
训练网络的简单例程可参考这篇简书笔记。数据输入的代码片段如下所示。Pytorch
提供DataLoader
类用于处理数据的输入,提供做训练的网络输入值(inputs
)和真值(labels
)。它提供很多数据输入的便捷操作,比如输入数据批处理大小(batch_size
),数据集是否需要随机打乱(shuffle
),数据集读取使用的线程个数(num_workers
)。
from torch.utils.data import DataLoader
trainloader = DataLoader(dataset=trainset, batch_size=4, shuffle=True, num_workers=4)
for i, data in enumerate(trainloader, 0):
# get the input
inputs, labels = data
对于DataLoader
而言,我最为关心的问题是:它是怎样读取输入数据的?从代码上看,通过调用enumerate
函数读取输入数据。再刨根问底地分析,enumerate
函数则是调用DataLoader
类的类成员变量trainset
的__getitem__
函数。trainset
是一个类,它公有继承自类dataset
。说起来是不是有点绕呢?我用一个示意图去解释它们。在图中,Get Data Function
是使用者需要定义的函数,比如训练数据读取,数据增广等。
图1:DataLoader
的调用关系图
2.2 mmDetection中的数据输入
上一节讨论了Pytorch
中Dataloader
类的原理图示,这一节看看mmDetection
中数据输入的原理图示。有一点要明确的是,mmDetection
的底层调用的还是Dataloader
。下面的原理图示以三维目标检测SA-SSD
为例子。
图2:mmDetection
初始化DataLoader
类的图示
由上图可见,mmDetection
首先调用函数get_dataset
,根据配置文件cfg.data.train
初始化一个数据集处理的类KittiLiDAR
(相当于图1中的类trainset
),然后再调用函数build_dataloader
,使用类KittiLiDAR
初始化类Dataloader
。
上述的流程需要频繁地做类实例化操作,需要用到函数obj_from_dict
。初始化一个类需要初始条件(比如img=cam_1
,flip=True
等等),我们把初始条件写进一个字典型变量里面(比如{"img":cam_1,"flip":True,...}
),然后调用函数obj_from_dict
,输入需要初始化类的类型以及初始化字典型变量,就可以完成这个类的初始化。
3. 结束语
这篇博客讨论了Dataloader
类访问训练数据的流程,以及mmDetection
调用Dataloader
类的流程。