😎😎😎物体检测-系列教程 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码
4、create_dataloader函数
- create_dataloader函数是在train.py的main函数中调用train函数,然后在train函数中调用create_dataloader函数
- create_dataloader函数的位置:yolov5/utils/datasets.py
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
with torch_distributed_zero_first(rank):
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
augment=augment, # augment images
hyp=hyp, # augmentation hyperparameters
rect=rect, # rectangular training
cache_images=cache,
single_cls=opt.single_cls,
stride=int(stride),
pad=pad,
image_weights=image_weights,
prefix=prefix)
batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
dataloader = loader(dataset,
batch_size=batch_size,
num_workers=nw,
sampler=sampler,
pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
return dataloader, dataset
- 定义create_dataloader函数,传入以下参数:
- path,数据集路径
- imgsz,图像尺寸,640*640
- batch_size,默认16
- stride,从输入到输出的降采样的比例
- opt,配置选项
- hyp,超参数,包含学习率、学习率衰减、动量
- augment,是否进行数据增强
- cache,是否缓存图片,每次读入数据时都会使用opencv读取,但是数据集需要反复读取,如果将数据和标签读成一个比如ndarray的缓存大文件,下次迭代训练的时候直接读这个大文件
- pad,卷积的pad填充圈数
- rect,是否使用矩形训练,一般情况下都是将输入数据resize成一个正方形,或者将矩形的宽使用0进行填充使其变成一个正方形,但是这样做会浪费数据,可以考虑直接矩形进行训练
- rank,分布式训练多个进程的索引
- world_size,分布式训练的总进程数
- workers,指定的线程数
- image_weights,是否根据图像权重采样,因为图像的对象类别有多种,如果这个类别数量不均衡的话,可以选择给每种类别一个权重,来保持这个均衡
- quad,是否使用四边形输入
- prefix,前缀字符串,用于指定日志或输出信息的前缀,这有助于在训练过程中区分和标识不同阶段或不同数据集的信息,在使用大规模数据集或进行复杂的训练过程中,可能会有多个dataloader同时在用,例如分别用于训练集、验证集和测试集。这个参数可以为每个数据加载器生成的日志或输出添加一个独特的标识符,查看日志可以轻松识别信息是来自哪个阶段或哪个数据集的
- 如果使用了分布式训练(即多机多卡,或单机多卡),则必须等待当前第1个进行读取完数据后才可以进行进程的并行,如果不是分布式训练(即单机单卡),这行代码相当于没有
- dataset ,创建LoadImagesAndLabels类的数据集实例,其中传入的参数在1中大部分已解释,single_cls,如果是单类别分类会设置这个分类标签为0
- batch_size ,如果样本总数比batch_size 还小,将batch_size 调整为样本总数
- nw ,基于CPU核心数、batch_size 、和指定的工作线程数来计算用于数据加载的工作线程数
- sampler ,如果在分布式训练环境中,则创建一个 DistributedSampler 用于数据采样,否则不使用采样器
- loader ,如果按照图像权重采样使用标准的DataLoader ,否则使用自定义的 InfiniteDataLoader
- dataloader ,创建dataloader实例
- 其中pin_memory是内存锁定,当设置为True时,dataloader会将数据加载到CPU的固定内存中,即保持在物理内存中,不会被交换到磁盘上,这可以减少访问这些内存时的延迟,可以加快将数据移动到GPU的速度
- collate_fn是一个函数,如果quad 为True使用LoadImagesAndLabels.collate_fn4(是一个标准的批处理函数,将图像和标签组合成批次)作为collate_fn,否则使用LoadImagesAndLabels.collate_fn(处理四边形标注的数据)
- 返回制作好的dataloader