Pytorch-DataLoader 和 Dataset

参考:
https://blog.csdn.net/zw__chen/article/details/82806900

https://blog.csdn.net/Threelights/article/details/88680540

一 要点总结

  • 1 torch.utils.data.Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中。
  • 2 将数据包装为Dataset类有两种方式:一种是用 torch.utils.data.Dataset.TensorDataset 来将数据包装成Dataset类,一种是自己写一个继承 torch.utils.data.Dataset类 ,实现类中的 len 方法和getitem 方法
  • 3 torchvision.datasets中的所有数据集都是 torch.utils.data.dataset 的子类,因而都实现了 getitemlen 方法,可以直接传递给 torch.utils.data.dataloader

二 示例

1 用 torch.utils.data.Dataset.TensorDataset 来将数据包装成Dataset类

import h5py
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, Dataset

# 查看数据

train_dataset = h5py.File('datasets/train_signs.h5', "r")
train_dataset
# out: <HDF5 file "train_signs.h5" (mode r)>
train_dataset.keys()
# out: <KeysViewHDF5 ['list_classes', 'train_set_x', 'train_set_y']>
list_classes,train_set_x,train_set_y  = train_dataset['list_classes'],train_dataset['train_set_x'],train_dataset['train_set_y']
list_classes, train_set_x, train_set_y
# out: (<HDF5 dataset "list_classes": shape (6,), type "<i8">,
# out: <HDF5 dataset "train_set_x": shape (1080, 64, 64, 3), type "|u1">,
# out: <HDF5 dataset "train_set_y": shape (1080,), type "<i8">)

# 用 torch.utils.data.Dataset.TensorDataset 来将数据包装成Dataset类

x_train = np.array(train_dataset["train_set_x"][:]) # your train set features
x_train = np.transpose(x_train, (0, 3, 1, 2))
y_train = np.array(train_dataset["train_set_y"][:]) # your train set labels
y_train = y_train.reshape((1, y_train.shape[0])).T
X_train_tensor = torch.tensor(x_train, dtype=torch.float)/255
Y_train_tensor = torch.tensor(y_train, dtype=torch.long)
train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

for epoch in range(2):
    for i,data in enumerate(train_loader):
        inputs,labels = data
        print(epoch,i,"inputs",inputs.size(),"labels",labels.size())

在这里插入图片描述

2 继承 torch.utils.data.Dataset类 ,实现类中的 len 方法和getitem 方法

# 继承 Dataset类 
# 重写 len 方法,该方法提供了dataset的大小; getitem 方法, 该方法支持从 0 到 len(self)的索引
class TrainDataset(Dataset):
    """
        下载数据、初始化数据,都可以在这里完成
    """
    def __init__(self):
        train_dataset = h5py.File('datasets/train_signs.h5', "r") # 读取数据
        x_train,y_train = np.array(train_dataset["train_set_x"][:]),np.array(train_dataset["train_set_y"][:]) # 读取数据为array
        x_train,y_train = np.transpose(x_train, (0, 3, 1, 2)),y_train.reshape((1, y_train.shape[0])).T # 改变array格式
        X_train_tensor,Y_train_tensor = torch.tensor(x_train, dtype=torch.float)/255,torch.tensor(y_train, dtype=torch.long) # 将array转换为tensor
        self.x_data =  X_train_tensor
        self.y_data = Y_train_tensor
    
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return x_train.shape[0] 
#  实例化这个类,就得到Dataset类型的数据,将这个类传给DataLoader
trainDataset =  TrainDataset()
type(trainDataset)
# out: __main__.TrainDataset
len(trainDataset)
# out: 1080
train_loader2 = DataLoader(dataset=trainDataset,
                          batch_size=64,
                          shuffle=True)

for epoch in range(2):
    for i, data in enumerate(train_loader2):
        inputs, labels = data
        # 接下来就是跑模型的环节了,我们这里使用print来代替
        print("epoch:", epoch, "的第" , i, "个inputs", inputs.size(), "labels", labels.size())

在这里插入图片描述

3 torchvision.datasets.mnist使用示例

from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

# 查看数据

train_dataset = mnist.MNIST(root='./train', train=True)
train_dataset
# out: Dataset MNIST
# out:    Number of datapoints: 60000
# out:    Root location: ./train
# out:    Split: Train
type(train_dataset)
# out: torchvision.datasets.mnist.MNIST
len(train_dataset)
# out: 60000
train_dataset[0]
# out: (<PIL.Image.Image image mode=L size=28x28 at 0x1323FD26D88>, 5)
train_dataset[0][1]
# out: 5

# 将这个类传给DataLoader

train_loader = DataLoader(dataset=train_dataset,batch_size=64,shuffle=True)
for epoch in range(2):
    for i, data in enumerate(train_loader):
        inputs, labels = data
        # 接下来就是跑模型的环节了,我们这里使用print来代替
        print("epoch:", epoch, "的第" , i, "个inputs", inputs.size(), "labels", labels.size())
# TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>
# 必须是tensors, numpy arrays, numbers, dicts or lists,这里是class 'PIL.Image.Image'

# DataLoader承接的Dataset类必须是tensors, numpy arrays, numbers, dicts or lists,因而用transform=ToTensor()将其转换为tensor
train_dataset = mnist.MNIST(root='./train', train=True,transform=ToTensor())
train_dataset
# out: Dataset MNIST
# out:    Number of datapoints: 60000
# out:    Root location: ./train
# out:    Split: Train
# out:    StandardTransform
# out:Transform: ToTensor()
type(train_dataset)
# out: torchvision.datasets.mnist.MNIST
len(train_dataset)
# out: 60000
train_dataset[0]
# out: 

在这里插入图片描述

train_loader = DataLoader(dataset=train_dataset,batch_size=64,shuffle=True)

for epoch in range(2):
    for i, data in enumerate(train_loader):
        inputs, labels = data
        # 接下来就是跑模型的环节了,我们这里使用print来代替
        print("epoch:", epoch, "的第" , i, "个inputs", inputs.size(), "labels", labels.size())

在这里插入图片描述

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: 如果你是在问 PyTorch 中的数据集和数据加载器,那么我很愿意为您解答。 PyTorch 是一个开源深度学习框架,其中包含了用于加载和预处理数据的工具。其中最重要的两个组件是数据集 (Dataset) 和数据加载器 (DataLoader)。 数据集是一个 PyTorch 类,它定义了如何读取数据、如何访问数据以及如何将数据转换为张量。您可以使用内置的数据集类,例如 torchvision.datasets 中的 ImageFolder,或者自定义数据集类。 数据加载器是一个 PyTorch 类,它可以从数据集中读取数据并将其批量加载到内存中。数据加载器可以进行并行加载,以提高加载速度,并且可以通过 shuffle 参数来随机打乱数据。 举个例子,如果您有一个图像数据集,您可以使用以下代码来创建数据集和数据加载器: ``` import torch import torchvision # 创建数据集 dataset = torchvision.datasets.ImageFolder(root='path/to/data', transform=transforms.ToTensor()) # 创建数据加载器 dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4) ``` 然后,您可以使用以下代码来读取数据: ``` for inputs, labels in dataloader: # 处理输入数据 ... ``` 希望对您有所帮助! ### 回答2: PyTorch是一种广泛使用的深度学习框架,具有易于使用的API和优秀的性能。其中,Dataset和DataLoader是两个非常重要的类,它们可以帮助我们有效地加载和处理数据。 Dataset是一个抽象的概念,用于表示一组数据。我们可以继承它并重写其中的方法,以实现对不同数据集的适配。在初始化时,我们需要传递一个数据集,比如说图片数据集,然后在DataLoader中使用这个数据集,实现数据的准备和加载。在自定义Dataset时,我们需要定义__getitem__和__len__两个方法,分别用于返回数据集中的某个数据和数据总数。 DataLoader是一个非常实用的工具,用于加载数据并把数据变成可迭代的对象,其中包含了批量大小、数据是否随机等设置。我们可以设置num_workers参数,用多个进程来读取数据提高读取数据的速度。通过使用DataLoader,我们可以很方便地迭代整个数据集,可以按批次加载和处理数据。 当我们使用在线学习时,经常需要不断地读取数据并进行训练。在应用中,我们会遇到许多不同的数据集,其中可能包含不同的数据类型,比如图像、音频、文本等。使用Dataset和DataLoader类,我们可以轻松处理这些数据,从而使我们的深度学习应用具有更广泛的适用性和扩展性。 总之,Dataset和DataLoader是PyTorch中非常重要的类,它们可以帮助我们非常方便地进行数据的处理和加载。无论你想要使用哪种数据集,它们都能够很好地适配。在实际应用中,我们可以灵活地使用这两个类来加载和准备数据并进行训练,从而加快应用的速度并提高深度学习的精度。 ### 回答3: PyTorch是一个流行的深度学习框架,它提供了Dataset和DataLoader这两个类来帮助我们更方便地处理数据。 Dataset可以看作是一个数据集,它定义了如何读取数据。官方提供了两种Dataset:TensorDataset和ImageFolder。TensorDataset是用来处理张量数据,而ImageFolder则是用来处理图像数据。如果我们需要使用其他类型的数据,我们可以通过重写Dataset类中的__getitem__和__len__方法来实现。 在实现Dataset之后,我们需要将数据读取到内存中,在模型训练时提供给模型,这时我们就需要使用到DataLoader了。DataLoader可以看作是一个数据加载器,它会自动将Dataset中的数据批量读取到内存中,并且支持数据的分布式加载。 在使用DataLoader时我们可以设置很多参数,比如batch_size表示每个batch的大小,shuffle表示是否打乱数据顺序,num_workers表示使用多少线程读取数据等等。这些参数都可以帮助我们更好地利用硬件资源,提高训练速度和效率。 使用PyTorch的Dataset和DataLoader可以帮助我们更方便快捷地处理数据,并且让我们可以更专注于模型的设计和训练。但我们也要注意一些细节问题,比如数据读取是否正确、内存使用是否合理等等。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值