2. PyTorch中数据的读取 - Dataset

Dataset demo

    1. 前置基础知识
      os库的使用
        import os
        dir_path = "dataset/train/ants"
        img_path_list = os.listdir(dir_path)
        #此时img_path_list为文件列表
        img_path_list[0]
    
        >>> 'xxx.jpg'
    
    1. 本文使用的数据集样式
      数据集格式

其中 i m a g e image image文件夹中为jpg格式图片, l a b e l label label文件夹中为txt文本。
一般设置 r o o t _ d i r root\_dir root_dir为所有数据集的根目录, 设置 t r a i n i m g train_img trainimg为数据相对根目录的地址,设置 l a b e l _ d i r label\_dir label_dir为标签相对根目录的地址(名字随意,自己能区分就好)。
使用 o s . p a t h . j o i n ( r o o t _ d i r , l a b e l _ d i r ) os.path.join(root\_dir, label\_dir) os.path.join(root_dir,label_dir)函数将他们拼接起来。

    1. demo
from torch.utils.data import Dataset
import os
from PIL import Image


class MyData(Dataset):

    # 数据类初始化
    def __init__(self, root_dir, train_img_dir_name, label_dir):
        # 定义数据主文件夹地址
        self.root_dir = root_dir
        # 定义标签完整地址(根据不同数据做调整)
        self.label_dir = os.path.join(self.root_dir, label_dir)
        # 定义图片文件夹路径名
        self.train_img_dir_name = train_img_dir_name
        # 地址拼接
        self.path = os.path.join(self.root_dir, self.train_img_dir_name)
        # 数据列表
        self.img_path = os.listdir(self.path)
        self.img_label = os.listdir(self.label_dir)

    # 得到单个数据(图片)
    def __getitem__(self, idx):
        # 从列表中得到单一数据的name
        img_name = self.img_path[idx]
        # 获取标签
        img_label = self.img_label[idx]
        # 得到数据地址
        img_item_path = os.path.join(self.root_dir, self.train_img_dir_name, img_name)
        # 使用PIL中的Image库打开图片
        img = Image.open(img_item_path)
        # 获取标签,从对应的txt文件中读取标签名
        img_item_label = os.path.join(self.label_dir, img_label)
        self.label = ""
        with open(img_item_label, "r") as f:
            self.label = f.read()
        return img, self.label

    # 得到数据大小
    def __len__(self):
        return len(self.img_path)


# 创建一个实例
root_dir = 'dataset/train/'
ants_train_img = 'ants_image'
ants_label_dir = "ants_label"
ants_dataset = MyData(root_dir, ants_train_img, ants_label_dir)

# 创建第二个实例
bees_train_img = 'bees_image'
bees_label_dir = 'bees_label'
bees_dataset = MyData(root_dir, bees_train_img, bees_label_dir)

# 测试是否得到蚂蚁图片
# img, label = ants_dataset[0]
# img.show()
# print(label)
# print(len(ants_dataset))

# 测试是否得到蜜蜂图片
# img, label = bees_dataset[0]
# img.show()
# print(label)
# print(len(bees_dataset))

# 定义完整训练数据集
train_dataset = ants_dataset + bees_dataset
print(len(train_dataset))

# 测试
img, label = train_dataset[124]
img.show()
print(label)
print(len(train_dataset))

PyTorch数据读取是构建深度学习模型的重要一环。为了高效处理大规模数据集,PyTorch提供了三个主要的工具:Dataset、DataLoader和TensorDatasetDataset是一个抽象类,用于自定义数据集。我们可以继承Dataset类,并重写其的__len__和__getitem__方法来实现自己的数据加载逻辑。__len__方法返回数据集的大小,而__getitem__方法根据给定的索引返回样本和对应的标签。通过自定义Dataset类,我们可以灵活地处理各种类型的数据集。 DataLoader是数据加载器,用于对数据集进行批量加载。它接收一个Dataset对象作为输入,并可以定义一些参数例如批量大小、是否乱序等。DataLoader能够自动将数据集划分为小批次,将数据转换为Tensor形式,然后通过迭代器的方式供模型训练使用。DataLoader在数据准备和模型训练的过程起到了桥梁作用。 TensorDataset是一个继承自Dataset的类,在构造时将输入数据和目标数据封装成Tensor。通过TensorDataset,我们可以方便地处理Tensor格式的数据集。TensorDataset可以将多个Tensor按行对齐,即将第i个样本从各个Tensor取出,构成一个新的Tensor作为数据集的一部分。这对于处理多输入或者多标签的情况非常有用。 总结来说,Dataset提供了自定义数据集的接口,DataLoader提供了批量加载数据集的能力,而TensorDataset则使得我们可以方便地处理Tensor格式的数据集。这三个工具的配合使用可以使得数据处理变得更加方便和高效。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值