目录
视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
一、数据加载与编号
在机器学习和深度学习任务中,数据加载是一个非常重要的步骤。在许多情况下,我们需要将数据加载到模型中,同时还需要获取与数据相关的标签。Dataset
类在PyTorch中提供了两个主要功能:
- 如何获取每一个数据以及其标签。
- 告诉我们总共有多少个数据点。
这两个功能是数据加载的关键部分,而Dataset
类帮助我们轻松实现这些功能。
二、数据集的组织方式
数据集的组织方式有多种形式,但在本文中,我们将关注两种常见的方式:
- 文件夹的名称即为数据的标签。
- 文件名和标签分别位于两个不同的文件夹中,标签可以使用文本文件(如txt)进行存储。
三、使用PyTorch的Dataset类
- 导入必要的的库
import os from PIL import Image from torch.utils.data import Dataset
__init__(self, root_dir, label_dir)
方法用于初始化数据集对象。它接受两个参数:root_dir
是数据集的根目录,label_dir
是与数据标签相关的子目录。class MyData(Dataset): def __init__(self, root_dir, label_dir): self.root_dir = root_dir self.label_dir = label_dir self.path = os.path.join(self.root_dir, self.label_dir) self.img_path = os.listdir(self.path)
__getitem__(self, idx)
方法用于获取数据和其标签。它接受一个索引idx
,通过该索引获取特定数据点的图像和标签。def __getitem__(self, idx): img_name = self.img_path[idx] img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) img = Image.open(img_item_path) label = self.label_dir return img, label
__len__(self)
方法返回数据集的长度,也就是数据点的总数。def __len__(self): return len(self.img_path)
- 完整代码如下所示:
# 导入必要的库 import os from PIL import Image from torch.utils.data import Dataset # 创建一个自定义的数据集类 MyData,继承自 PyTorch 的 Dataset 类 class MyData(Dataset): # 初始化方法,用于设置数据集的基本信息 def __init__(self, root_dir, label_dir): # 存储数据集根目录和标签子目录 self.root_dir = root_dir self.label_dir = label_dir # 拼接数据路径,组成完整的路径 self.path = os.path.join(self.root_dir, self.label_dir) # 获取指定标签子目录下的所有文件名 self.img_path = os.listdir(self.path) # 获取数据点的方法,根据索引(idx)返回图像和标签 def __getitem__(self, idx): # 获取特定索引的图像文件名 img_name = self.img_path[idx] # 拼接图像文件的完整路径 img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 使用 PIL 库打开图像文件 img = Image.open(img_item_path) # 标签是当前数据点所在的标签子目录名称 label = self.label_dir # 返回图像和标签 return img, label # 获取数据集长度的方法,返回数据点的总数 def __len__(self): return len(self.img_path) # 在程序的入口点运行以下代码 if __name__ == '__main__': # 定义数据集的根目录 root_dir = "dataset/train" # 分别指定两个不同的标签子目录 ants_label_dir = "ants_image" bees_label_dir = "bees_image" # 创建两个数据集对象,分别用于加载不同标签的数据 ants_dataset = MyData(root_dir, ants_label_dir) bees_dataset = MyData(root_dir, bees_label_dir) # 合并两个数据集以创建一个用于训练的数据集 train_dataset = ants_dataset + bees_dataset