一、简介
1. torch中的dataloader:
torch.utils.data.DataLoader(dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None, *,
prefetch_factor=2,
persistent_workers=False)
2. 参数解释:
dataset:
data和labels,一般继承torch.utils.data.Dataset
,可在其中做数据增广等处理。batch_size:
default=1shuffle:
default=False,是否打乱数据集。sampler:
定义采样策略,若自己定义,则shuffle
要设定False。如使用难采样三元组损失时,就需要在一个batch内对当前样本进行自定义的采样规则。batch_sampler:
和sampler
类似,但是返回的是一个batch的index,与shuffle
,sampler
,drop_last
互斥num_workers:
default=0,注意windows中使用时候一般设置为0,要不然会出错或者速度更慢collate_fn:
将一个batch内的imgs和labels合并,如果只返回img和label,那么可以使用默认的collate_fn
,但是如果返回img box label
,每一个img的box数目不一定相同,所以就需要在这个函数里面加入当前box属于当前batch的哪一张图片,就需要自定义collate_fn
将对应的数据合并成一个batch。再如,使用mosaic
数据增强时,需要将处理好的4张图片拼接起来,同样对应的label也要拼接。pin_memory:
内存大就开,不大就不开。表示要将load进来的数据是否要拷贝到pin_memory区中,其表示生成的Tensor数据是属于内存中的锁页内存区,这样将Tensor数据转义到GPU中速度就会快一些,默认为False。通常情况下,数据在内存中要么以锁页的方式存在,要么保存在虚拟内存(磁盘)中,设置为True后,数据直接保存在锁页内存中,后续直接传入cuda;否则需要先从虚拟内存中传入锁页内存中,再传入cuda,这样就比较耗时了,但是对于内存的大小要求比较高。drop_last:
最后的数据不够一个batch_size时,是否选择舍去。默认为Falsetimeout:
default=0,一般不用管worker_init_fn:
一般不用管prefetch_factor:
一般不用管persistent_workers:
一般不用管
3. 思路:
常见的需要自己自定义按照自己需要重写参数有dataset
、sampler
、collate_fn
二、构造过程
1. 思路:
先将读取的dataset
函数进行重写(包括数据增强、矩形推理等都在这个部分,要注意当图像坐标发生变化时对应的label box的坐标也要对应的进行处理,最好将输出后的img和box坐标画出来,看看是否处理错误),以及采样规则sampler
,batch中的样本结合处理。
1. 构建dataset
:
from torch.utils.data import Dataset
class LoadImgsAndLabels(Dataset):
def __init__(self, path, img_szie=640, batch_size=16, argument=False, hyp=None):
'''
一般情况下,这里有图片路径(label路径可根据图片路径对应的改成自己的文件路径)
img_size,图像一般裁剪到统一大小进行输入
argument,是否使用数据增强
hyp,处理过程中的超参列表,如数据增强中图像上下左右翻转的概率,有多大的概率进行mosaic等等
还有其他需要的一些参数需要自定义
'''
# 1. 获得path下所有图片的绝对路径 self.img_files: []
'''
这里使用 try except Exception as e 判断是否加载数据错误之类的
将path转换为 pathlib.Path 可以生成与os无关的分隔符,可以结合os.sep使用。
用glob.glob和设置一个所有图片或者视频后缀列表进行判断, 筛选出所有的图片或者需要的文件。
'''
# 2. 根据获得的图片路径self.img_files 转换为 标签路径:self.label_files
'''
写个转换函数 img2label_paths
self.label_files = img2label_paths(self.img_files)
'''
def __len__(self):
'''
返回当前数据集的长度(有多少张图片)
return len(self.img_files)
'''
pass
def __getitem__(self, index):
'''
这部分包含读取数据,数据增强, 一般一次性执行batch_size次
可分为训练和测试
eg:
训练 数据增强: mosaic(random_perspective) + hsv + 上下左右翻转
测试 数据增强: letterbox
'''
# 1. 读取图片和labels
'''
常规:使用index读取图片和labels
数据增强:不同的数据增强可以写不同的读取图片的函数
如:img, labels = load_mosaic(self, index)
img, (h0, w0), (h, w) = load_image(self, index) + letterbox + labels对应的处理
'''
pass
2. 构建sampler
:
遇到再写,一般在使用三元组损失之类的loss时候需要重写,比如图片匹配、行人重识别任务中使用。
或者是分布式采样:
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
3. 构建collate_fn
:
def collate_fn(batch):
"""
pytorch的DataLoader打包一个batch的数据集时要经过此函数进行打包 通过重写此函数实现标签与图片对应的划分,
一个batch中哪些标签属于哪一张图片,形如
[[0, 6, 0.5, 0.5, 0.26, 0.35],
[0, 6, 0.5, 0.5, 0.26, 0.35],
[1, 6, 0.5, 0.5, 0.26, 0.35],
[2, 6, 0.5, 0.5, 0.26, 0.35],]
前两行标签属于第一张图片, 第三行属于第二张。。。
"""
img, label, path, shapes = zip(*batch) # transposed
for i, l in enumerate(label):
l[:, 0] = i # add target image index for build_targets()
# 这里之所以拼接的方式不同是因为img拼接的时候它的每个部分的形状是相同的,都是[3, 736, 736]
# 而label的每个部分的形状是不一定相同
# 如果每张图的目标个数是相同的,那我们就可能不需要重写collate_fn函数了
return torch.stack(img, 0), torch.cat(label, 0), path, shapes
4. 创建最后的dataloader
:
def create_dataloader(path, imgsz, batch_size, hyp=None, augment=False):
dataset = LoadImagesAndLabels(path, imgsz, batch_size, augment=augment, hyp=hyp)
loader = torch.utils.data.DataLoader
dataloader = loader(dataset,
batch_size=batch_size,
num_workers=nw,
sampler=sampler,
pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn)
return dataloader, dataset