[PyTorch] torch.utils.data.Dataset & torch.utils.data.DataLoader

1. torch.utils.data.Dataset

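torch.utils.data.Dataset is an abstract class, so to use it you create your own subclass that implements its interface.
What has to be implemented? Every subclass must implement __getitem__(self, index), which returns the data sample for the given index. Subclasses may optionally implement __len__(), which returns the size of the dataset; many samplers, and the default options of DataLoader, rely on it.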
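Because Dataset itself does not define __len__, a subclass that skips it can still be indexed, but it cannot be passed to len(), and the DataLoader's default sequential sampler (which calls len(dataset)) fails as well. A minimal sketch; SquaresNoLen is a hypothetical class made up for this illustration:

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresNoLen(Dataset):
    """Hypothetical map-style dataset that implements only __getitem__."""

    def __getitem__(self, index):
        # Return a 1-element float tensor holding index**2
        return torch.tensor([float(index) ** 2])


ds = SquaresNoLen()
print(ds[3])  # tensor([9.])  (indexing works because __getitem__ is defined)

# len() fails: neither SquaresNoLen nor Dataset defines __len__
try:
    len(ds)
    len_ok = True
except TypeError:
    len_ok = False
print("len() works:", len_ok)  # False

# DataLoader's default SequentialSampler calls len(dataset), so iteration fails too
try:
    next(iter(DataLoader(ds, batch_size=2)))
    loader_ok = True
except TypeError:
    loader_ok = False
print("DataLoader works:", loader_ok)  # False
```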

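Python magic methods: methods whose names start and end with two underscores. Python calls them automatically when certain events happen to a class or object, so to give your own class custom behavior you override them. A few are briefly introduced below:
__getitem__(self, key): defines the behavior of fetching a specified element of the container, i.e. self[key]
__len__(self): defines the behavior when the object is passed to len()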
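To see the two methods above in action, here is a minimal sketch of a hypothetical read-only container, Squares (not part of any library), where len(obj) dispatches to __len__ and obj[key] dispatches to __getitem__:

```python
class Squares:
    """Hypothetical read-only container: element i is i*i, for i in range(n)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Called by len(obj)
        return self.n

    def __getitem__(self, key):
        # Called by obj[key]; IndexError signals "out of range" to Python
        if not 0 <= key < self.n:
            raise IndexError(key)
        return key * key


sq = Squares(4)
print(len(sq))   # 4  -> Squares.__len__()
print(sq[3])     # 9  -> Squares.__getitem__(3)
print(list(sq))  # [0, 1, 4, 9]
```

Note that list(sq) works even without __iter__: Python falls back to calling __getitem__ with 0, 1, 2, ... until IndexError is raised.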


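  • If the container you want to customize is immutable, you only need to define the two magic methods __len__() and __getitem__().
  • If the container you want to customize is mutable, then in addition to __len__() and __getitem__() you also need to define __setitem__() (defines the behavior of setting a specified element, i.e. self[key] = value) and __delitem__() (defines the behavior of deleting a specified element, i.e. del self[key]).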
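For the mutable case, a minimal sketch of a hypothetical dict-backed container, Bag (made up for illustration), that adds __setitem__ and __delitem__:

```python
class Bag:
    """Hypothetical mutable container backed by a dict."""

    def __init__(self):
        self._items = {}

    def __len__(self):
        return len(self._items)

    def __getitem__(self, key):
        return self._items[key]

    def __setitem__(self, key, value):
        # Called by obj[key] = value
        self._items[key] = value

    def __delitem__(self, key):
        # Called by del obj[key]
        del self._items[key]


bag = Bag()
bag["a"] = 1   # __setitem__
bag["b"] = 2
del bag["a"]   # __delitem__
print(len(bag), bag["b"])  # 1 2
```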
import torch
import numpy as np

class subDataset(torch.utils.data.Dataset):
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    def __len__(self):
        return len(self.Data)
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.Tensor(self.Label[index])
        return data, label

Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8], [9, 10], [11, 12]])
Label = np.asarray([[0], [1], [0], [2], [3], [4]])
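# Create the dataset
dataset = subDataset(Data, Label)
print(dataset) # e.g. <__main__.subDataset object at 0x7f87a57aa128> (the address varies)
# Get the dataset length; len(dataset) calls __len__() under the hood
print(dataset.__len__()) # 6
print(len(dataset)) # 6
# Get a sample; dataset[0] calls __getitem__(0) under the hood
print(dataset.__getitem__(0)) # (tensor([1., 2.]), tensor([0.]))
print(dataset[0]) # (tensor([1., 2.]), tensor([0.]))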
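For data that fits in memory, PyTorch also provides torch.utils.data.TensorDataset, which indexes a set of tensors along their first dimension. The sketch below swaps it in for subDataset; it assumes converting the NumPy arrays to float32 once up front (subDataset instead converts inside every __getitem__ call):

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset

Data = np.asarray([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
Label = np.asarray([[0], [1], [0], [2], [3], [4]])

# Convert once to float32 tensors, matching what torch.Tensor(...) produces in subDataset
tds = TensorDataset(torch.from_numpy(Data).float(), torch.from_numpy(Label).float())
print(len(tds))  # 6
print(tds[0])    # (tensor([1., 2.]), tensor([0.]))
```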

2. torch.utils.data.DataLoader

torch.utils.data.DataLoader combines a dataset and a sampler, and provides an iterable over the given dataset.
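Roughly, with shuffle=False the DataLoader wraps a SequentialSampler in a BatchSampler, looks up each batch of indices with dataset[i], and passes the samples to collate_fn. A sketch of that pipeline done by hand, assuming a plain Python list of tensors as a stand-in dataset (a list supports both len() and []):

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

data = [torch.tensor([float(i)]) for i in range(6)]  # stand-in map-style dataset

# Batches of indices, as produced for batch_size=2, shuffle=False, drop_last=False
batches_of_indices = list(BatchSampler(SequentialSampler(data), batch_size=2, drop_last=False))
print(batches_of_indices)  # [[0, 1], [2, 3], [4, 5]]

# Per batch: index the dataset, then collate (the default collate stacks tensors)
manual = [torch.stack([data[i] for i in idx]) for idx in batches_of_indices]
auto = list(DataLoader(data, batch_size=2, shuffle=False))
print(all(torch.equal(m, a) for m, a in zip(manual, auto)))  # True
```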

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)

  • dataset (Dataset) – dataset from which to load the data.
  • batch_size (int, optional) – how many samples per batch to load (default: 1).
  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
  • sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
  • batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
    Like sampler, but returns only one batch of indices at a time. Once this parameter is specified, batch_size, shuffle, sampler, and drop_last can no longer be specified (they are mutually exclusive).
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
    This sets how many processes handle data loading; 0 means all data is loaded in the main process (default: 0).
  • collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
  • pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example in the PyTorch documentation.
    With True, the data loader copies the tensors into CUDA pinned (page-locked) memory before returning them, which can speed up host-to-GPU transfers.
  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
  • worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading, i.e. a per-worker initialization function. (default: None)
  • prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default: 2)
  • persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This keeps the workers' Dataset instances alive. (default: False)
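A small sketch of how drop_last and batch_sampler from the list above behave on a 6-sample dataset (a plain list of tensors stands in for the dataset; num_workers is left at 0):

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

data = [torch.tensor([float(i)]) for i in range(6)]  # 6 samples

# drop_last: 6 samples with batch_size=4 -> one full batch plus a partial batch of 2
keep_last = [len(b) for b in DataLoader(data, batch_size=4, drop_last=False)]
drop_last = [len(b) for b in DataLoader(data, batch_size=4, drop_last=True)]
print(keep_last, drop_last)  # [4, 2] [4]

# batch_sampler supplies whole batches of indices by itself
bs = BatchSampler(SequentialSampler(data), batch_size=3, drop_last=False)
with_bs = [len(b) for b in DataLoader(data, batch_sampler=bs)]
print(with_bs)  # [3, 3]

# Passing batch_size (or shuffle, sampler, drop_last) together with batch_sampler is rejected
try:
    DataLoader(data, batch_sampler=bs, batch_size=3)
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```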

What are map-style and iterable-style datasets? A map-style dataset in PyTorch implements the __getitem__() and __len__() protocols, while an iterable-style dataset implements the __iter__() protocol.
Reference: https://stackoverflow.com/questions/63347149/pytorch-dataset-map-style-vs-iterable-style
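For contrast with the map-style subDataset, here is a minimal sketch of an iterable-style dataset built on torch.utils.data.IterableDataset; CountStream is a hypothetical class, and num_workers is left at 0:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class CountStream(IterableDataset):
    """Hypothetical iterable-style dataset: yields 0, 1, ..., n-1 as 1-element tensors."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        # Samples are produced in order; there is no random access by index
        for i in range(self.n):
            yield torch.tensor([float(i)])


loader = DataLoader(CountStream(5), batch_size=2)
shapes = [tuple(batch.shape) for batch in loader]
print(shapes)  # [(2, 1), (2, 1), (1, 1)]
```

With num_workers > 0, each worker gets its own copy of the dataset and would replay the whole stream, so an iterable-style dataset normally needs its own sharding logic (for example based on torch.utils.data.get_worker_info()) to avoid duplicate samples.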

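dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
# shuffle=True: a different sample order every epoch; num_workers=4: four worker processes load the batches.
# Where workers are started with spawn (Windows, macOS), run this under if __name__ == "__main__":
for i, item in enumerate(dataloader):
    print(i, item)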
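Shuffling only reorders the samples: every sample still appears exactly once per epoch. A sketch under two assumptions: TensorDataset stands in for subDataset, and a seeded torch.Generator (the generator parameter from the signature) makes the order reproducible; num_workers is left at 0 to keep it single-process:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for subDataset: 6 samples with 2 features each, plus labels
data = torch.arange(1, 13, dtype=torch.float32).reshape(6, 2)
labels = torch.tensor([[0.], [1.], [0.], [2.], [3.], [4.]])
ds = TensorDataset(data, labels)

g = torch.Generator().manual_seed(0)
loader = DataLoader(ds, batch_size=2, shuffle=True, generator=g)

for epoch in range(2):
    # The first feature (1, 3, 5, ...) identifies each sample
    order = [x[:, 0].tolist() for x, _ in loader]
    print(epoch, order)  # a new permutation each epoch

# One more epoch: collect every sample's first feature
seen = sorted(v for x, _ in loader for v in x[:, 0].tolist())
print(seen)  # [1.0, 3.0, 5.0, 7.0, 9.0, 11.0]
```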

The collate_fn parameter (callable, optional)

collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.


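# For each field j (0 = data, 1 = label): add a batch dim to x[i][j] with unsqueeze(0), then concatenate along dim 0.
# The outer parentheses make this lambda return a generator, not a list.
collate_fn = lambda x: (
    torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))], 0)
    for j in range(len(x[0]))
)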
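The same logic can be written as a named function that returns a list instead of a generator, so the result can be reused and, unlike a lambda, the function can be pickled when num_workers > 0 uses the spawn start method. A sketch; concat_collate is a hypothetical name, and TensorDataset stands in for subDataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def concat_collate(batch):
    """Hypothetical named version of the collate lambda.

    batch is a list of (data, label) tuples. For each field j, take sample[j]
    from every sample, add a leading batch dim with unsqueeze(0), and
    concatenate along dim 0 (equivalent to torch.stack(..., 0)).
    """
    num_fields = len(batch[0])
    return [torch.cat([sample[j].unsqueeze(0) for sample in batch], 0)
            for j in range(num_fields)]


ds = TensorDataset(torch.arange(1, 13, dtype=torch.float32).reshape(6, 2),
                   torch.tensor([[0.], [1.], [0.], [2.], [3.], [4.]]))
for i, item in enumerate(DataLoader(ds, batch_size=3, collate_fn=concat_collate)):
    print(i, [t.shape for t in item])
# 0 [torch.Size([3, 2]), torch.Size([3, 1])]
# 1 [torch.Size([3, 2]), torch.Size([3, 1])]
```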


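# No collate_fn given, so the default collate stacks the samples field by field into [data_batch, label_batch]
dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, shuffle=False, num_workers=4)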
for i, item in enumerate(dataloader):
    print(i, item)
# 0 [tensor([[1., 2.],[3., 4.],[5., 6.]]), tensor([[0.],[1.],[0.]])]
# 1 [tensor([[ 7.,  8.],[ 9., 10.],[11., 12.]]), tensor([[2.],[3.],[4.]])]


dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=3,
    collate_fn=lambda x: (torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))], 0)
                          for j in range(len(x[0]))))
for i, item in enumerate(dataloader):
    # item is a generator here, so list() materializes the per-field tensors
    print(i, list(item))

# 0 [tensor([[1., 2.],[3., 4.],[5., 6.]]), tensor([[0.],[1.],[0.]])]
# 1 [tensor([[ 7.,  8.],[ 9., 10.],[11., 12.]]), tensor([[2.],[3.],[4.]])]

Set collate_fn to lambda x: x to see what collate_fn receives:
a list of length batch_size, where each element is the (data, label) tuple returned by __getitem__.

dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, collate_fn=lambda x:x)
for i, item in enumerate(dataloader):
    print(i, item)
# 0 [(tensor([1., 2.]), tensor([0.])), (tensor([3., 4.]), tensor([1.])), (tensor([5., 6.]),tensor([0.]))]
# 1 [(tensor([7., 8.]), tensor([2.])), (tensor([ 9., 10.]), tensor([3.])), (tensor([11., 12.]), tensor([4.]))]

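In summary: for each index, the Dataset returns one tuple of features, f_num fields of length f_len each (here f_num = 2: data with f_len = 2, label with f_len = 1). Wrapped in a DataLoader, batch_size such tuples are fetched and passed to collate_fn, which outputs f_num tensors, each of shape batch_size × f_len.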
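The f_num × batch_size × f_len picture can be checked directly against PyTorch's default collate function. This sketch imports it from torch.utils.data.dataloader (newer releases also expose it as torch.utils.data.default_collate) and compares it with a manual transpose-and-stack:

```python
import torch
from torch.utils.data.dataloader import default_collate

# A batch as collate_fn receives it: batch_size=3 tuples, each with f_num=2 fields
batch = [(torch.tensor([1., 2.]), torch.tensor([0.])),
         (torch.tensor([3., 4.]), torch.tensor([1.])),
         (torch.tensor([5., 6.]), torch.tensor([0.]))]

# Manual transpose-and-stack: zip(*batch) groups field j across all samples
manual = [torch.stack(field) for field in zip(*batch)]
auto = default_collate(batch)

print(len(auto))                # 2 (f_num fields)
print([t.shape for t in auto])  # [torch.Size([3, 2]), torch.Size([3, 1])]
print(all(torch.equal(m, a) for m, a in zip(manual, auto)))  # True
```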





