dataloader介绍
dataloader
概述
参考博客
DataLoader
是深度学习中重要的数据处理工具之一,旨在有效加载、处理和管理大规模数据集,用于训练和测试机器学习和深度学习模型。
DataLoader
是一个用于批量加载数据的工具,它可以将数据集分成多个小批量(mini-batch)
,并逐个加载,以适应模型训练的需要。
DataLoader
主要用于两个关键任务:数据加载和批次处理
- 数据加载:
DataLoader
可以从不同来源加载数据,如硬盘上的文件、数据库、网络等。它能够自动将数据集划分为小批次,从而减小内存需求,确保数据的高效加载。 - 数据批次处理:每个批次由多个样本组成,可以并行地进行数据预处理和数据增强。这有助于提高模型训练的效率,同时确保每个批次的数据都经过适当的处理。
collate_fn
collate_fn 是一个自定义函数,用于在 PyTorch 的 DataLoader 中定义如何将单个样本组合成一个批次(batch)。具体来说,collate_fn 函数会在每次从 DataLoader 中取出一个批次的数据时被调用,用于对数据进行整理和转换。
主要作用
collate_fn
:返回值为最终构建的batch数据;在这一步中处理dataset的数据,将其调整成期望的数据格式。
将一个批次的数据样本整理成适合模型输入的格式,特别是将数据转换为 PyTorch 张量(Tensor),以便于后续的模型训练和推理。
- 自定义数据堆叠:将单个样本组合成一个批次,处理数据的不同形状或类型。
- 数据转换:在批次数据组成之前进行必要的转换操作,例如数据类型转换、数据增强等。
在代码中的使用
在本代码中,unet_dataset_collate 函数就是一个 collate_fn 函数。它的作用是将一个批次的数据样本(图像、PNG 数据和分割标签)整理成适合模型输入的格式。具体步骤包括将数据从列表转换为 NumPy 数组,再转换为 PyTorch 张量。
代码详解
# DataLoader中collate_fn使用
def unet_dataset_collate(batch):
images = []
pngs = []
seg_labels = []
for img, png, labels in batch:
images.append(img)
pngs.append(png)
seg_labels.append(labels)
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
pngs = torch.from_numpy(np.array(pngs)).long()
seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
return images, pngs, seg_labels
这段代码定义了一个名为 unet_dataset_collate
的函数,用于在 PyTorch 的 DataLoader 中自定义批处理方式。函数将一个批次的数据样本(batch)转换为适合模型输入的格式。
代码解释
__init__函数
在 DataLoader 中,init 函数的主要作用是初始化数据集对象,并为后续的数据加载和处理做好准备。
UnetDataset 类的 init 函数在 DataLoader 中的作用包括:
- 数据集初始化:通过传入的参数(如 annotation_lines、input_shape 等)初始化数据集对象,使其包含所有必要的信息。
- 数据预处理:在初始化过程中,可以对数据进行预处理,如归一化、裁剪等,以便后续的模型训练。
- 数据分割:将数据集分割成训练集和验证集(通过 train 参数),以便在训练过程中进行模型评估。
- 路径管理:通过 dataset_path 参数指定数据集的存储路径,方便数据的加载和管理。
# UnetDataset 类的初始化方法,接受五个参数:annotation_lines、input_shape、num_classes、train 和 dataset_path。
def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
# super() 函数用于调用父类的初始化方法。在这里,它调用了 UnetDataset 类的父类的 __init__ 方法,确保父类的初始化逻辑也被执行。这对于继承自其他类的类非常重要。
super(UnetDataset, self).__init__()
# self 代表类的实例。self.annotation_lines 将传入的 annotation_lines 参数赋值给实例属性 annotation_lines
self.annotation_lines = annotation_lines
self.length = len(annotation_lines)
self.input_shape = input_shape
self.num_classes = num_classes
self.train = train
self.dataset_path = dataset_path
解释 super 和 self
- super
super()
函数用于调用父类的方法。在多重继承的情况下,它确保正确调用父类的方法,避免重复调用。这里,它调用了 UnetDataset 类的父类的 init 方法。 - self
self
是类的实例的引用。它用于访问类的属性和方法。在类的方法中,self 必须作为第一个参数传递,以便方法能够访问实例的属性和其他方法。
collate_fn
# DataLoader中collate_fn使用
# 函数定义:net_dataset_collate(batch):定义了一个函数,接收一个批次的数据样本batch。
def unet_dataset_collate(batch):
# 初始化列表:
# images = []:用于存储所有图像数据。
# pngs = []:用于存储所有 PNG 格式的数据。
# seg_labels = []:用于存储所有分割标签数据
images = []
pngs = []
seg_labels = []
# 遍历批次数据:
# 遍历批次中的每个样本,假设每个样本包含图像、PNG 数据和分割标签。
# images.append(img):将图像数据添加到 images 列表中。
# pngs.append(png):将 PNG 数据添加到 pngs 列表中。
# seg_labels.append(labels):将分割标签数据添加到 seg_labels 列表中。
for img, png, labels in batch:
images.append(<