Pytorch加载自己的数据集(使用DataLoader加载Dataset)

https://www.pytorchtutorial.com/pytorch-custom-dataset-examples/

https://blog.csdn.net/l8947943/article/details/103733473

1. 我们需要加载自己的数据集,使用Dataset和DataLoader

  • Dataset:是被封装进DataLoader里,实现该方法封装自己的数据和标签。
  • DataLoader:被封装入DataLoader迭代器里,实现该方法达到数据的划分。

2.Dataset

主要继承该方法必须实现两个方法:

  • _getitem_()
  • _len_()
import torch
import numpy as np


# 定义GetLoader类,继承Dataset方法,并重写__getitem__()和__len__()方法
class GetLoader(torch.utils.data.Dataset):
	# 初始化函数,得到数据
    def __init__(self, data_root, data_label):
        self.data = data_root
        self.label = data_label
    # index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        return data, labels
    # 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
    def __len__(self):
        return len(self.data)

# 随机生成数据,大小为10 * 20列
source_data = np.random.rand(10, 20)
# 随机生成标签,大小为10 * 1列
source_label = np.random.randint(0,2,(10, 1))
# 通过GetLoader将数据进行加载,返回Dataset对象,包含data和labels
torch_data = GetLoader(source_data, source_label)

3.DataLoader

提供对Dataset的操作,操作如下:

torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)

数含义如下:

  • dataset: 加载torch.utils.data.Dataset对象数据
  • batch_size: 每个batch的大小
  • shuffle:是否对数据进行打乱
  • drop_last:是否对无法整除的最后一个datasize进行丢弃
  • um_workers:表示加载的时候子进程数,一般GPU使用

因此,在实现过程中我们测试如下(紧跟上述用例):

from torch.utils.data import DataLoader

# 读取数据
datas = DataLoader(torch_data, batch_size=6, shuffle=True, drop_last=False, num_workers=2)

此时,我们的数据已经加载完毕了,只需要在训练过程中使用即可。

4.查看数据

我们可以通过迭代器(enumerate)进行输出数据,测试如下:

for i, data in enumerate(datas):
	# i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
    print("第 {} 个Batch \n{}".format(i, data))

5.使用自己保存的“npy”数据集进行加载

定义一个继承dataset的类

import numpy as np
from torch.utils.data.dataset import Dataset
import torch

# 定义CustomDataset类,继承Dataset方法,并重写__getitem__()和__len__()方法
class CustomDataset(torch.utils.data.Dataset):
    # 初始化函数,得到数据
    def __init__(self, pathData, pathLabel):
        self.data = np.load(pathData)  # 传入了dataset X的路径,并使用np.load进行加载数据
        self.label = np.load(pathLabel)  # 传入了label Y的路径

    # index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        return data, labels

    # 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
    def __len__(self):
        return len(self.data)

加载数据

我自己的数据集格式:

from torch.utils.data import DataLoader
from CustomDataset import CustomDataset
pathX = './datasetXPro.npy'
pathY = './datasetYPro.npy'
torch_data = CustomDataset(pathX, pathY)
# 读取数据
datas = DataLoader(torch_data, batch_size=6, shuffle=True, drop_last=False, num_workers=2)

for i, data in enumerate(datas):
    # i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
    print("第 {} 个Batch \n{}".format(i, data))

 

 

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值