Loading datasets in PyTorch: DataLoader explained, plus a generic dataset-loading template

DataLoader

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
        batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False,
        drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None,
        generator=None, *, prefetch_factor=2, persistent_workers=False)

A data loader combines a dataset and a sampler, and provides an iterable over the given dataset.

  • (Original description: Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.)

DataLoader supports both map-style and iterable-style datasets, with single- or multi-process loading, a customizable loading order, and optional automatic batching (collation) and memory pinning.

  • (Original description: The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.)

Parameters and their meanings

  • dataset (Dataset) – dataset from which to load the data.

  • batch_size (int, optional) – how many samples per batch to load (default: 1).

  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).

  • sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.

  • batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.

  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

  • collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

  • pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.

  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)

  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)

  • worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

  • prefetch_factor (int, optional, keyword-only arg) – Number of samples loaded in advance by each worker. 2 means there will be a total of 2 * num_workers samples prefetched across all workers. (default: 2)

  • persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This keeps the workers' Dataset instances alive. (default: False)
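
To make several of these arguments concrete, here is a minimal runnable sketch; the toy TensorDataset and all values below are invented purely for illustration:

import torch
from torch.utils.data import DataLoader, TensorDataset

# A toy map-style dataset: 10 feature vectors with integer labels.
features = torch.randn(10, 4)
labels = torch.arange(10)
toy_dataset = TensorDataset(features, labels)

loader = DataLoader(
    toy_dataset,
    batch_size=4,      # 10 samples -> batches of 4, 4, 2
    shuffle=True,      # reshuffle indices every epoch
    num_workers=0,     # load in the main process
    pin_memory=False,  # set True when batches will be copied to a CUDA device
    drop_last=True,    # discard the final incomplete batch of 2
)

for xb, yb in loader:
    print(xb.shape, yb.shape)  # torch.Size([4, 4]) torch.Size([4])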

A generic dataset-loading template

First, build a class that provides the Dataset:
import torch.utils.data as Data
class MyDataSet(Data.Dataset):
    def __init__(self, x_train, x_label, y_train, y_label):
        super(MyDataSet, self).__init__()
        self.x_train = x_train
        self.x_label = x_label
        self.y_train = y_train
        self.y_label = y_label

    def __len__(self):
        return self.x_train.shape[0]

    def __getitem__(self, idx):
        return self.x_train[idx], self.x_label[idx], self.y_train[idx], self.y_label[idx]
  • Subclass torch.utils.data.Dataset to build your own MyDataSet class
  • Implement __init__, which takes in all of the data and labels
  • Implement __len__, which returns a value representing the length of the dataset
  • Implement __getitem__, which, given an index idx, returns the data and labels at that index
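As a quick sanity check, with made-up toy arrays (the shapes here are illustrative assumptions, not from the original data), the class behaves like any indexable sequence:

import numpy as np

# Toy arrays, purely for illustration: 4 samples, sequence length 3.
x_train = np.zeros((4, 3)); x_label = np.ones((4, 3))
y_train = np.zeros((4, 3)); y_label = np.ones((4, 3))

ds = MyDataSet(x_train, x_label, y_train, y_label)
print(len(ds))  # 4, via __len__
print(ds[0])    # the idx-0 row of each array, via __getitem__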
Then use torch.utils.data.DataLoader to construct your own DataLoader:
MyDataLoader = Data.DataLoader(MyDataSet(x_train, x_label, y_train, y_label),
                               batch_size=2, shuffle=True)
Then each batch can be read in a for loop:
for epoch in range(30):
    for x_train, x_label, y_train, y_label in MyDataLoader:
        '''
        x_train: [batch_size, src_len]
        x_label: [batch_size, src_len]
        y_train: [batch_size, src_len]
        y_label: [batch_size, src_len]
        '''

The shapes of the data read here are:

      x_train: [batch_size, src_len]
      x_label: [batch_size, src_len]
      y_train: [batch_size, src_len]
      y_label: [batch_size, src_len]
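
These shapes come from the default collate_fn: it stacks batch_size per-sample entries of shape [src_len] into one tensor of shape [batch_size, src_len]. A quick check, reusing the toy arrays from the sanity check above:

loader = Data.DataLoader(MyDataSet(x_train, x_label, y_train, y_label),
                         batch_size=2, shuffle=True)
xb, xl, yb, yl = next(iter(loader))
print(xb.shape)  # torch.Size([2, 3]), i.e. [batch_size, src_len]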

A more flexible example

from data_pro import load_data_and_labels, Data
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

x_train, x_test, label_train, label_test = load_data_and_labels('***.pos', '***.neg')
  • Here the dataset is read from files and has Python list type
  • That is, x_train is a list: [[**], [**], [**], [**]]
  • Likewise, label_train is a list: [*, *, *, *]
  • Each feature vector is a list and each label is a single value, so one element (a list) of x_train corresponds to one element (a value) of label_train
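
For concreteness, a hypothetical stand-in for load_data_and_labels might produce data shaped like this (the real function reads the '***.pos' and '***.neg' files; load_data_and_labels_stub and its values are invented to match the description above):

from sklearn.model_selection import train_test_split

def load_data_and_labels_stub():
    # Variable-length feature lists, each with one scalar label.
    x = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]
    labels = [1, 0, 1, 0]
    return train_test_split(x, labels, test_size=0.25)

x_train, x_test, label_train, label_test = load_data_and_labels_stub()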

Then we can wrap the lists in a torch.utils.data.Dataset; here Data is the custom Dataset subclass imported from data_pro above:

train_data = Data(x_train, label_train)

Here the samples are variable-length Python lists, which the default batching (collation) cannot stack into a single tensor, so we specify our own collate_fn to control how samples are merged into a batch:

def collate_fn(batch):
    # batch is a list of (data, label) samples; zip(*batch)
    # separates it into a tuple of data and a tuple of labels
    data, label = zip(*batch)
    return data, label

You can think of collate_fn as a transform applied to each batch: it takes the list of samples drawn for the batch and splits them into two parts, the data and the labels. That is:

  • the resulting batch is [batch_size, x_size] for the data and [batch_size, label_size] for the labels
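
Concretely, on a toy batch of (sample, label) pairs:

batch = [([1, 2, 3], 0), ([4, 5], 1)]  # a list of (sample, label) pairs
data, label = zip(*batch)
print(data)   # ([1, 2, 3], [4, 5])
print(label)  # (0, 1)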

After that, define the train_loader:

train_loader = DataLoader(train_data, batch_size=5, shuffle=True, collate_fn=collate_fn)

Let's analyze how data is fetched from such a loader in a for loop, using a small self-contained example:

import torch.utils.data as Data
import numpy as np

x_train = np.array([[1, 2, 3],
                    [4, 5, 6]])
label_train = np.array([0, 1])


class MyDataset(Data.Dataset):
    def __init__(self, x_train, label_train):
        super(MyDataset, self).__init__()
        self.x_train = x_train
        self.label_train = label_train
    
    def __len__(self):
        return self.x_train.shape[0]
    
    def __getitem__(self, idx):
        return self.x_train[idx], self.label_train[idx]

def collate_fn(batch):
    data, label = zip(*batch)
    return data, label

MyDataLoader = Data.DataLoader(MyDataset(x_train, label_train),
                               batch_size=1, shuffle=True)

for train_item in MyDataLoader:
    print(f"type of train_item is {type(train_item)}")
    print(train_item)

Output:

type of train_item is <class 'list'>
[tensor([[1, 2, 3]], dtype=torch.int32), tensor([0], dtype=torch.int32)]
type of train_item is <class 'list'>
[tensor([[4, 5, 6]], dtype=torch.int32), tensor([1], dtype=torch.int32)]

With the default collate_fn, the numpy samples were converted to tensors and given a batch dimension, as seen above. If we instead pass our own collate_fn, the returned result becomes:

MyDataLoader = Data.DataLoader(MyDataset(x_train, label_train),
                               batch_size=1, shuffle=True, collate_fn=collate_fn)

Output:

type of train_item is <class 'tuple'>
((array([4, 5, 6]),), (1,))
type of train_item is <class 'tuple'>
((array([1, 2, 3]),), (0,))

Remember to convert the arrays back to Tensors.
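
For example, a minimal sketch of the conversion inside the loop; np.stack assumes the per-sample arrays in each batch share the same length:

import numpy as np
import torch

for data, label in MyDataLoader:
    x = torch.as_tensor(np.stack(data))   # shape (batch_size, 3)
    y = torch.as_tensor(np.array(label))  # shape (batch_size,)
    print(x, y)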
