【pytorch】torch.utils.data.Dataset & torch.utils.data.Dataloader

1. torch.utils.data.Dataset

torch.utils.data.Dataset 是一个抽象类,所以使用时需要自己创建一个子类实现接口。
需要实现什么呢?所有子类需要实现__getitem__(self, index),其作用是根据index获取对应的data;另外,可选择的实现__len__(),作用是返回数据集大小。

python魔法方法(Magic methods)。魔术方法在类或对象的某些事件出发后会自动执行,以两个下划线开头、两个下划线结尾的方法,如果希望根据自己的程序定制特殊功能的类,那么就需要对这些方法进行重写。下面简答介绍几种:
__getitem__(self, key) , 定义获取容器中指定元素的行为,相当于 self[key]
__len__(self), 定义当被 len() 调用时的行为

python容器。在python中,像序列类型(如列表,元组和字符串)或映射类型(如字典)都属于容器类型。

  • 如果你希望定制的容器是不可变的话,你只需要定义__len__()__getitem__()这两个魔法方法。
  • 如果你希望定制的容器是可变的话,那你除了定义__len__()__getitem__()方法外,还需要定义__setitem__() (定义设置容器中指定元素的行为,相当于self[key] = value)和 __delitem__()(定义删除容器中指定元素的行为,相当于del self[key])两个方法.
import torch
import numpy as np

#创建子类
class subDataset(torch.utils.data.Dataset):
    #初始化,定义数据内容和标签
    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label
    #返回数据集大小
    def __len__(self):
        return len(self.Data)
    #得到数据内容和标签
    def __getitem__(self, index):
        data = torch.Tensor(self.Data[index])
        label = torch.Tensor(self.Label[index])
        return data, label

Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8], [9, 10], [11, 12]])
Label = np.asarray([[0], [1], [0], [2], [3], [4]])
# 创建数据集
dataset = subDataset(Data, Label)
print(dataset) # <__main__.subDataset object at 0x7f87a57aa128>
# 获取数据集长度信息
print(dataset.__len__()) # 6
# 获取数据
print(dataset.__getitem__(0)) # (tensor([1., 2.]), tensor([0.]))
print(dataset[0]) # (tensor([1., 2.]), tensor([0.]))



2. torch.utils.data.Dataloader

torch.utils.data.Dataloader 包括 datatset 和 sampler, 提供一个iterable。

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)
详细参数:

  • dataset (Dataset) – dataset from which to load the data. 数据集
  • batch_size (int, optional) – how many samples per batch to load (default: 1).
  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
  • sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with len implemented. If specified, shuffle must not be specified.
    自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False
  • batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
    与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
    这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
  • collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
    默认的collate_fn是将一个list的samples组成一个mini-batch的函数。也就是说传入的是一个list的(data,label)
  • pin_memory (bool, optional)– If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
    如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.
  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
    如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了。如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
  • worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
  • prefetch_factor (int, optional, keyword-only arg) – Number of samples loaded in advance by each worker. 2 means there will be a total of 2 * num_workers samples prefetched across all workers. (default: 2)
    如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0
  • persistent_workers (bool, optional) – If True, the data loader will not shutdown the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)每个worker初始化函数

什么是map-style和iterable-style? A map-style dataset in Pytorch has the getitem() and len() and iterable-style datasets has iter() protocol.
参考: https://stackoverflow.com/questions/63347149/pytorch-dataset-map-style-vs-iterable-style

dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
for i, item in enumerate(dataloader):
    print(i, item)

参数 collate_fn (callable, optional)

collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

所以,collate_fn是组成mini-batch的必要操作,即:

collate_fn=lambda x:(
 torch.cat(
  [x[i][j].unsqueeze(0) for i in range(len(x))], 0
  ) for j in range(len(x[0]))
 )

先看使用默认的collate_fn的输出:

dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, shuffle=False, num_workers=4)
for i, item in enumerate(dataloader):
    print(i, item)
# 0 [tensor([[1., 2.],[3., 4.],[5., 6.]]), tensor([[0.],[1.],[0.]])]
# 1 [tensor([[ 7.,  8.],[ 9., 10.],[11., 12.]]), tensor([[2.],[3.],[4.]])]

替换collate_fn后:

dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, 
                                         collate_fn=lambda x:(torch.cat([x[i][j].unsqueeze(0) for i in range(len(x))], 0) for j in range(len(x[0])) ) )
for i, item in enumerate(dataloader):
    print(i, list(item))

# 0 [tensor([[1., 2.],[3., 4.],[5., 6.]]), tensor([[0.],[1.],[0.]])]
# 1 [tensor([[ 7.,  8.],[ 9., 10.],[11., 12.]]), tensor([[2.],[3.],[4.]])]

collate_fn 设置为 lambda x:x ,看看传入的是什么:
传入一个长度为batch_size的list,每个位置上存储着(data, label)二元tuple

dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, collate_fn=lambda x:x)
for i, item in enumerate(dataloader):
    print(i, item)
# 0 [(tensor([1., 2.]), tensor([0.])), (tensor([3., 4.]), tensor([1.])), (tensor([5., 6.]),tensor([0.]))]
# 1 [(tensor([7., 8.]), tensor([2.])), (tensor([ 9., 10.]), tensor([3.])), (tensor([11., 12.]), tensor([4.]))]

Dataset 每次负责索引输出一个(features)元组 = f_num*f_len, 装入DataLoader后,索引batch_size 个 (fearures) 到 collate_fn,再输出 f_num * batch_size * f_len

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值