【自学 PyTorch 】第二课 —— 【代码实战记录】PyTorch 数据集读取完整代码

本文介绍了如何使用PyTorch加载自定义数据集,通过定义一个继承自`Dataset`的类,实现数据的读取。在`__init__`方法中,设置数据路径和标签目录;`__getitem__`方法用于根据索引获取图像和标签;`__len__`返回数据列表长度。实例化类并打印数据集长度,展示了一个简单的数据加载流程。
摘要由CSDN通过智能技术生成

PyTorch 数据集读取完整代码


写一个系列代码实战,争取每天都更。

倒逼自己赶紧提升写 Python 代码的手感。

一、代码

在这里插入图片描述

import os

from PIL import Image
from torch.utils.data import Dataset


class Nemo(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path_list = os.listdir(self.path)

    def __getitem__(self, idx):
        image_name = self.img_path_list[idx]
        image_item_path = os.path.join(self.root_dir, self.label_dir, image_name)
        img = Image.open(image_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path_list)


root_dir = "D:\\Python_In_One\\Project\\XiaoTuDui\\data\\train"
ants_label_dir = "ants_label"
ants_dataset = Nemo(root_dir, ants_label_dir)

print(len(ants_dataset))

二、理解

如何理解数据的加载呢?

就是说,要想获取自己电脑里的数据,读取它,那么就要遵守 pytorch 加载数据的规则。

他的规则就是 定义一个 类,继承 Dataset (from torch.utils.data import Dataset),并且,在类中,定义三个函数,分别是 初始化 init、获得每一个数据 getitem、数据长度 len。

这里面的过程,要很清楚:

1、 路径、合并路径、把文件夹中的每一个文件名称,做成一个列表(这是init要做的事情);

2、 访问init中的列表,把列表的名称逐一传递给一个变量,命名为name,再次合并路径,并且把文件名连接在路径之后,接下来,用PIL中的Image.open函数,读取(加载)上述路径的文件(命名为img)(这里肯定是图像了),返回 图像img和标签 label(这是getitem的工作);

3、 最后用len()返回列表的长度。

定义好 类 以后,后面就可以实例化这个类,定义参数(本例其实是一个路径,一个夹名称了),名称可以和定义类中的不一样,但是位置要对应(奥,这可能是Python课程里说的位置参数?)。
引用之前定义的类,把上述参数,传递进去。

最后打印自定义数据列表的长度。

参考内容

该案例是 上手学习 PyTorch 时,B站 up 【我是土堆】的代码实战。

提醒自己

在我的文件夹中,

文件名为:P6_7_read_data.py
  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值