小白科研笔记:深入理解mmDetection框架——数据输入

1. 前言

这篇博客讨论mmDetection框架的数据输入过程。主要讨论SA-SSD的数据流。

2. 数据输入

2.1 DataLoader类

Pytorch训练网络的简单例程可参考这篇简书笔记。数据输入的代码片段如下所示。Pytorch提供DataLoader类用于处理数据的输入,提供做训练的网络输入值(inputs)和真值(labels)。它提供很多数据输入的便捷操作,比如输入数据批处理大小(batch_size),数据集是否需要随机打乱(shuffle),数据集读取使用的线程个数(num_workers)。

from torch.utils.data import DataLoader

trainloader = DataLoader(dataset=trainset, batch_size=4, shuffle=True, num_workers=4)

for i, data in enumerate(trainloader, 0):
    # get the input
    inputs, labels = data

对于DataLoader而言,我最为关心的问题是:它是怎样读取输入数据的?从代码上看,通过调用enumerate函数读取输入数据。再刨根问底地分析,enumerate函数则是调用DataLoader类的类成员变量trainset__getitem__函数。trainset是一个类,它公有继承自类dataset。说起来是不是有点绕呢?我用一个示意图去解释它们。在图中,Get Data Function是使用者需要定义的函数,比如训练数据读取,数据增广等。

在这里插入图片描述
图1:DataLoader的调用关系图

2.2 mmDetection中的数据输入

上一节讨论了PytorchDataloader类的原理图示,这一节看看mmDetection中数据输入的原理图示。有一点要明确的是,mmDetection的底层调用的还是Dataloader。下面的原理图示以三维目标检测SA-SSD为例子。

在这里插入图片描述
图2:mmDetection初始化DataLoader类的图示

由上图可见,mmDetection首先调用函数get_dataset,根据配置文件cfg.data.train初始化一个数据集处理的类KittiLiDAR(相当于图1中的类trainset),然后再调用函数build_dataloader,使用类KittiLiDAR初始化类Dataloader

上述的流程需要频繁地做类实例化操作,需要用到函数obj_from_dict。初始化一个类需要初始条件(比如img=cam_1flip=True等等),我们把初始条件写进一个字典型变量里面(比如{"img":cam_1,"flip":True,...}),然后调用函数obj_from_dict,输入需要初始化类的类型以及初始化字典型变量,就可以完成这个类的初始化。

3. 结束语

这篇博客讨论了Dataloader类访问训练数据的流程,以及mmDetection调用Dataloader类的流程。

  • 1
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值