Yolov5源码详解:数据加载篇(上)

第二篇:数据加载(上)

1-代码概述

utils/dataloaders.py是用于加载数据并创建数据加载器的工具类

  • 它实现了一系列函数和类,包括LoadImagesLoadStreamsLoadImagesAndLabelsHUBDatasetStats等。这些类和函数用于从文件系统加载图像和标签数据,并将其转换为模型可用的数据格式。
  • 同时实现了相关的数据加载器类,如InfiniteDataLoaderClassificationDataset,以便将数据加载到模型中进行训练和验证。
  • 还提供了一些辅助函数用于处理图像和标签数据

2-逐模块详解

2-1 InfiniteDataLoader

InfiniteDataLoader类是PyTorch框架中的DataLoader类的子类。DataLoader类是用于加载数据集的主要工具,它可以自动地把数据集划分成小批量(batch),并且可以在加载数据时使用多线程并行处理。

InfiniteDataLoader类重写了父类DataLoader的几个方法,实现了一个可以无限循环加载数据的数据加载器。具体来说,这个类的主要功能和工作方式如下:

__init__方法:这是类的初始化方法。在这个方法里,首先调用了父类DataLoader的初始化方法,然后用一个名为_RepeatSampler的对象替换了DataLoader的batch_sampler属性,这样可以使数据加载器在取尽所有样本后可以再从头开始取样本。最后,创建了一个迭代器self.iterator,用于在后面的方法中生成数据。

__len__方法:这个方法返回了数据集中的样本数量。

__iter__方法:这个方法定义了数据生成的方式。它利用前面创建的迭代器self.iterator,在每次循环时生成一个新的数据批次。由于self.iterator使用了_RepeatSampler作为批次采样器,所以当所有的数据都被取尽时,它会自动地从头开始,从而实现数据的无限循环。

2-2 LoadImagesAndLabels
  • 这类定义了一个用于加载图像和标签数据的数据集类LoadImagesAndLabels,用于训练和验证YOLOv5模型。该类继承自torch.utils.data.Dataset,并实现了__init____len____getitem__等方法。

    方法:

    • __len__(self):返回数据集的图像数量。
    • __getitem__(self, index):获取指定索引的图像和标签。
    • load_image(self, i):加载指定索引的图像。
    • cache_images_to_disk(self, i):将图像缓存到磁盘中。
    • load_mosaic(self, index):加载mosaic数据增强的图像和标签。
    • load_mosaic9(self, index):加载9-mosaic数据增强的图像和标签。
    • collate_fn(self, batch):将批次中的图像和标签进行组合。
    • collate_fn4(self, batch):将4-mosaic批次中的图像和标签进行组合。
2-3 create_dataloader
功能描述

create_dataloader创建一个数据加载器(DataLoader),用于在训练深度学习模型时加载数据集。它接受一个数据集的路径和各种参数,然后返回一个可以用于批量加载和处理该数据集的数据加载器。

以下是其处理步骤:

  1. 在分布式数据并行(DDP)的环境中,首先通过 torch_distributed_zero_first(rank) 确保数据集的 *.cache 文件只被初始化一次。
  2. 创建了一个 LoadImagesAndLabels 对象,该对象负责加载和处理图像及其相关的标签。这个过程中,会根据提供的各种参数(如是否进行数据增强,是否使用矩形批处理,是否缓存图像等)来进行相应的处理。
  3. 计算了数据加载器的批处理大小和工作线程的数量。批处理大小决定了每次加载的数据数量,而工作线程的数量则决定了数据加载的并行性。
  4. 根据是否使用分布式数据并行(DDP)以及是否进行数据shuffle,创建了一个相应的采样器(sampler)。
  5. 创建了一个数据加载器(DataLoader)。如果设置了 image_weightsTrue,则使用 DataLoader,否则使用 InfiniteDataLoader。数据加载器负责按照设定的批处理大小和采样器,批量加载和处理数据。
  6. 设置了一个随机数生成器,以确保在不同的训练过程中,数据加载的顺序是一致的。
  7. 最后,返回创建的数据加载器和数据集对象。

同时还使用了一些全局变量,如 LOGGER (用于记录日志), RANKPIN_MEMORY (用于处理分布式数据并行)。

输入
  • path:数据集路径,可以是包含图像文件的文件夹路径,也可以是包含图像文件路径的列表。
  • imgsz:图像尺寸,用于将图像调整为统一尺寸。
  • batch_size:批大小,每个批次包含的图像数量。
  • stride:图像步长,用于计算特征图大小。
  • single_cls:是否进行单类别训练。
  • hyp:超参数字典。
  • augment:是否进行数据增强。
  • cache:是否将图像缓存到内存或磁盘中加快训练速度。
  • pad:图像填充比例。
  • rect:是否使用矩形训练。
  • rank:当前进程的排名。
  • workers:数据加载器的工作线程数。
  • image_weights:是否使用图像权重。
  • quad:是否使用四通道图像。
输出
  • loader:数据加载器对象。
  • dataset:数据集对象。
示例代码
loader, dataset = create_dataloader(path='data/images', imgsz=640, batch_size=16, stride=32)

3-逐行注释

3-1 InfiniteDataLoader
class InfiniteDataLoader(dataloader.DataLoader): # 定义一个名为InfiniteDataLoader的类,该类继承自dataloader.DataLoader

    def __init__(self, *args, **kwargs): # 定义构造函数,接受任意数量的位置参数和关键字参数
        super().__init__(*args, **kwargs) # 调用父类的构造函数,传入同样的参数
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) # 设置self.batch_sampler属性为_RepeatSampler对象,传入self.batch_sampler作为参数
        self.iterator = super().__iter__() # 设置self.iterator为父类的迭代器

    def __len__(self): # 定义长度函数,返回batch_sampler.sampler的长度
        return len(self.batch_sampler.sampler)

    def __iter__(self): # 定义迭代器函数
        for _ in range(len(self)): # 迭代self的长度次数
            yield next(self.iterator) # 每次迭代返回self.iterator的下一个元素
3-2 create_dataloader
# 定义一个名为create_dataloader的函数,功能是创建一个数据加载器,该函数接受多个参数,包括数据路径、图像大小、批处理大小、步幅等
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0, rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False, seed=0):
    if rect and shuffle: # 如果rect为True且shuffle为True
        LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False') # 输出警告信息
        shuffle = False # 将shuffle设置为False
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = LoadImagesAndLabels( # 创建LoadImagesAndLabels对象,传入参数
            path, # 数据集路径
            imgsz, # 图像尺寸
            batch_size, # batch大小
            augment=augment,  # 是否进行数据增强
            hyp=hyp,  # 超参数
            rect=rect,  # 是否使用矩形batch
            cache_images=cache, # 是否缓存图像
            single_cls=single_cls, # 是否进行单类别训练
            stride=int(stride), # 步长
            pad=pad, # 填充
            image_weights=image_weights, # 是否使用图像权重
            prefix=prefix) # 前缀
	# 确定批处理大小,不能超过数据集长度
    batch_size = min(batch_size, len(dataset)) 
    # 获取CUDA设备数量
    nd = torch.cuda.device_count()  
    # 计算工作进程数量,不能超过CPU核数,批处理大小,以及设置的工作进程数
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) 
    # 如果是分布式环境,则创建一个分布式采样器,否则设置为None
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    # 如果使用image_weights,则使用DataLoader加载数据,否则使用InfiniteDataLoader加载数据,这两者的区别在于DataLoader允许对属性进行更新
    loader = DataLoader if image_weights else InfiniteDataLoader
    generator = torch.Generator() # 随机数生成器
    generator.manual_seed(6148914691236517205 + seed + RANK) # 设置随机数种子
    return loader(dataset, # 返回DataLoader或InfiniteDataLoader对象
                  batch_size=batch_size, # batch大小
                  shuffle=shuffle and sampler is None, # 是否进行shuffle
                  num_workers=nw, # worker数量
                  sampler=sampler, # 采样器
                  pin_memory=PIN_MEMORY, # 是否将数据存储在固定内存中
                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn, # 数据集合并函数
                  worker_init_fn=seed_worker, # worker初始化函数
                  generator=generator), dataset # 返回DataLoader或InfiniteDataLoader对象和数据集对象
  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

upDiff

你的鼓励将是我创作的最大动力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值