Pytorch框架学习记录1——Dataset类代码实战

Pytorch框架学习记录1——Dataset类代码实战

介绍

  • torch.utils.data.Dataset是代表自定义数据集方法的抽象类,你可以自己定义你的数据类继承这个抽象类,非常简单,只需要定义__len____getitem__这两个方法就可以。
  • 通过torch.utils.data.DataLoader类来定义一个新的迭代器,用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
这两个抽象类中用到的python知识点

能够熟练的使用python语言的技巧,是理解pytorch源码的关键。在torch.utils.data.Dataset和torch.utils.data.DataLoader这两个类中会用到python抽象类的魔法方法,包括__len__(self),getitem(self)和__iter__(self)

  • __len__(self) 定义当被len()函数调用时的行为(返回容器中元素的个数)
  • __getitem__(self)定义获取容器中指定元素的行为,相当于self[key],即允许类对象可以有索引操作。
  • __iter__(self)定义当迭代容器中的元素的行为

数据集下载地址:https://pan.baidu.com/s/1qNCOVz15mCSQEDZZXJAaoQ?pwd=qz2b
提取码:qz2b

1. 导入包

from torch.utils.data import Dataset
from PIL import Image
import os

2. 创建类

创建子类MyData,继承父类Dataset,并对函数进行重写。

class MyData(Dataset):

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)  # 路径拼接
        self.img_path = os.listdir(self.path)  # 获得文件夹下所有的文件名

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.img_path, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)

3. 调用

类进行实例化,并进行调用

root_dir = "C:\\Users\\hp\\PycharmProjects\\pythonProject\\Pytorch_Learning\\flower_data\\train"
daisy_label_dir = "daisy"
roses_label_dir = "roses"
daisy_dataset = MyData(root_dir, daisy_label_dir)
roses_dataset = MyData(root_dir, roses_label_dir)

train_dataset = daisy_dataset + roses_dataset
print("daisy:",len(daisy_dataset),"\nroses",len(roses_dataset),"\ndaisy+roses",len(train_dataset))
img, label = train_dataset[0]
img2, label = train_dataset[577]
img2.show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Yozu_Roo

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值