PyTorch深度学习——加载数据

本文介绍了PyTorch中Dataset和Dataloader的概念及其作用。Dataset类用于定义数据获取方式,它包含了数据集的组织结构。而Dataloader则将数据转化为适合训练的批次形式。通过__getitem__()方法,可以按索引访问Dataset中的数据项。示例展示了如何创建一个自定义数据集MyData,用于加载和展示图像,并将其与不同标签的数据集合并。
摘要由CSDN通过智能技术生成

Dataset:提供一种获取数据的方式
Dataloader:为后面的网络提供不同的数据形式

__getitem__(self,key):

把类中的属性定义为序列,可以使用__getitem__()函数输出序列属性中的某个元素,这个方法返回与指定键想关联的值。对序列来说,键应该是0~n-1的整数,其中n为序列的长度。对映射来说,键可以是任何类型。

如果在类中定义了__getitem__()方法,那么它的实例对象(假设为P)就可以以P[key]形式取值,当实例对象做P[key]运算时,就会调用类中的__getitem__()方法。当对类的属性进行下标的操作时,首先会被__getitem__() 拦截,从而执行在__getitem__()方法中设定的操作,如赋值,修改内容,删除内容等。
from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self,root_dir,label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
        img = Image.open(img_item_path)
        img.show()
        label = self.label_dir
        return img,label

    def __len__(self):
        return len(self.img_path)


root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)

train_dataset = ants_dataset + bees_dataset


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值