pytorch中DataLoader详解

功能初体验


import torch
import torch.utils.data as Data

if __name__ == '__main__':
    torch.manual_seed(1)  # reproducible

    BATCH_SIZE = 5  # 批训练的数据个数

    x = torch.linspace(11, 20, 10)  # x data: tensor([11., 12., 13., 14., 15., 16., 17., 18., 19., 20.])
    y = torch.linspace(1, 10, 10)  # y data: tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
    # print(x)
    # print(y)
    # 先转换成 torch 能识别的 Dataset
    torch_dataset = Data.TensorDataset(x, y) #[return:
                                                        # (tensor(x1),tensor(y1));
                                                        # (tensor(x2),tensor(y2));
                                                        # ......
    # print(torch_dataset, list(torch_dataset))


    # 把 dataset 放入 DataLoader
    loader = Data.DataLoader(
        dataset=torch_dataset,  # torch TensorDataset format
        batch_size=BATCH_SIZE,  # mini batch size
        shuffle=False,  # 要不要打乱数据 (打乱比较好)
        num_workers=0,  # 多线程来读数据
    )

    for epoch in range(3):  # 训练所有!整套!数据 3 次
        for step,(batch_x,batch_y) in enumerate(loader):  # 每一步 loader 释放一小批数据用来学习
                                    #return:
                                            #(tensor(x1,x2,x3,x4,x5),tensor(y1,y2,y3,y4,y5))
                                            #(tensor(x6,x7,x8,x9,x10),tensor(y6,y7,y8,y9,y10))
            # 假设这里就是你训练的地方...

            # 打出来一些数据
            print('Epoch: ', epoch, '| Step:', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())

在这里插入图片描述

参数简介

在这里插入图片描述
上图为源码中dataloader中所有的可选参数。除了第一个dataset参数外,其他均为可选参数。

  • Dataset:处理好的所有数据
  • batch_size:批数量
  • shuffle:打乱数据
  • sampler:采样机制,即从数据集里面取样本的方式(迭代器,每次返回一个样本)
  • batch_sampler:把sampler的采样的样本根据batch_size组织成一个batch返回
  • num_worker:加载数据的线程数
  • collate_fn:把batch_sampler返回的list结构的一个batch的样本打包成一个tensor的结构
  • pin_memory:将加载的数据拷贝到CUDA中的固定内存中,从而使数据更快地传输到支持cuda的gpu
  • drop_last:丢弃余数
  • timeout:如果是正数,表明等待从加载一个batch等待的时间,若超出设定的时间还没有加载完,就放弃这个batch,如果是0,表示不设置限制时间。默认为0
  • worker_init_fn:如果不是None ,它将在每个worker子进程上以worker id ([0, num_workers - 1] )作为输入调用,在seeding之后和数据加载之前。
  • generater:如果不是None,这个RNG将被RandomSampler用来生成随机索引,并被multiprocessing用来为worker生成’ base_seed ‘。 (默认值:’ ‘没有’ ')
  • prefetch_factor:提前加载多少个batch的数据,可以保证线程不会等待,每个线程都总有至少一个数据在加载。提升显卡利用率。
  • persistent_workers:如果为True,数据加载器将不会在数据集运行完一个Epoch后关闭worker进程。这允许维护worker数据集实例保持激活。(默认值:False),意思是运行完一个Epoch后并不会关闭worker进程,而是保持现有的worker进程继续进行下一个Epoch的数据加载。好处是Epoch之间不必重复关闭启动worker进程,加快训练速度。

Dataloader参数之间的互斥

值得注意的是,Dataloader的参数之间存在互斥的情况,主要针对自己定义的采样器:

  • sampler:如果自行指定了sampler参数,则shuffle必须保持默认值,即False
  • batch_sampler:如果自行指定了batch_sampler参数,则 batch_size、shuffle、sampler、drop_last 都必须保持默认值
  • 如果没有指定自己是采样器,那么默认的情况下(即sampler和batch_sampler均为None的情况下),dataloader的采样策略是如何的呢:

sampler:

  • shuffle = True:sampler采用 RandomSampler,即随机采样
  • shuffle = Flase:sampler采用 SequentialSampler,即按照顺序采样
  • batch_sampler:采用 BatchSampler,即根据 batch_size 进行batch采样
  • 上面提到的 RandomSampler、SequentialSampler和BatchSampler都是PyTorch自己实现的,且它们都是Sampler的子类。

Sampler

SequentialSampler

SequentialSampler就是一个按照顺序进行采样的采样器,接收一个数据集做参数(实际上任何可迭代对象都可),按照顺序对其进行采样:

from torch.utils.data import SequentialSampler

pseudo_dataset = list(range(10, 20))
for data in SequentialSampler(pseudo_dataset):
    print(data, end=" ")
0 1 2 3 4 5 6 7 8 9 

RandomSampler

RandomSampler 即一个随机采样器,返回随机采样的值,第一个参数依然是一个数据集(或可迭代对象)。还有一组参数如下:

  • replacement:bool值,默认是False,设置为True时表示可以采出重复的样本
  • num_samples:只有在replacement设置为True的时候才能设置此参数,表示要采出样本的个数,默认为数据集的总长度。有时候由于replacement置True的原因导致重复数据被采样,导致有些数据被采不到,所以往往会设置一个比较大的值
from torch.utils.data import RandomSampler

pseudo_dataset = list(range(10, 20))

randomSampler1 = RandomSampler(pseudo_dataset)
randomSampler2 = RandomSampler(pseudo_dataset, replacement=True, num_samples=20)

print("for random sampler #1: ")
for data in randomSampler1:
    print(data, end=" ")

print("\n\nfor random sampler #2: ")
for data in randomSampler2:
    print(data, end=" ")

for random sampler #1: 
4 5 2 9 3 0 6 8 7 1 

for random sampler #2: 
4 9 0 6 9 3 1 6 1 8 5 0 2 7 2 8 6 4 0 6 

WeightedRandomSampler

WeightedRandomSampler和RandomSampler的参数一致,但是不在传入一个dataset,第一个参数变成了weights,只接收一个一定长度的list作为 weights 参数,表示采样的权重,采样时会根据权重随机从 list(range(len(weights))) 中采样,即WeightedRandomSampler并不需要传入样本集,而是只在一个根据weights长度创建的数组中采样,所以采样的结果可能需要进一步处理才能使用。weights的所有元素之和不需要为1。

from torch.utils.data import WeightedRandomSampler

weights = [1,1,10,10]

weightedRandomSampler = WeightedRandomSampler(weights, replacement=True, num_samples=20)

for data in weightedRandomSampler:
    print(data, end=" ")
2 2 2 3 2 2 3 2 3 3 1 3 2 2 1 3 3 2 3 3 

详细使用可参考: WeightedRandomSampler使用案例

BatchSampler

其他Sampler在每次迭代都只返回一个索引,而BatchSampler的作用是对上述这类返回一个索引的采样器进行包装,按照设定的batch size返回一组具体数据,因其他的参数和上述的有些不同:

  • sampler:一个Sampler对象(或者一个可迭代对象)
  • batch_size:batch的大小
  • drop_last:是否丢弃最后一个可能不足batch size大小的数据
from torch.utils.data import BatchSampler
pseudo_dataset = list(range(10, 20))

batchSampler1 = BatchSampler(pseudo_dataset, batch_size=3, drop_last=False)
batchSampler2 = BatchSampler(pseudo_dataset, batch_size=3, drop_last=True)

print("for batch sampler #1: ")
for data in batchSampler1:
    print(data, end=" ")

print("\n\nfor batch sampler #2: ")
for data in batchSampler2:
    print(data, end=" ")
for batch sampler #1: 
[10, 11, 12] [13, 14, 15] [16, 17, 18] [19] 

for batch sampler #2: 
[10, 11, 12] [13, 14, 15] [16, 17, 18] 

SubsetRandomSampler

SubsetRandomSampler 可以设置子集的随机采样,多用于将数据集分成多个集合,比如训练集和验证集的时候使用:

from torch.utils.data import SubsetRandomSampler

pseudo_dataset = list(range(10, 20))

subRandomSampler1 = SubsetRandomSampler(pseudo_dataset[:7])
subRandomSampler2 = SubsetRandomSampler(pseudo_dataset[7:])

print("for subset random sampler #1: ")
for data in subRandomSampler1:
    print(data, end=" ")

print("\n\nfor subset random sampler #2: ")
for data in subRandomSampler2:
    print(data, end=" ")

for subset random sampler #1: 
14 15 11 16 13 10 12 

for subset random sampler #2: 
17 19 18 

参考:https://blog.csdn.net/qq_38962621/article/details/111146427

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Shashank497

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值