DataLoader PyTorch 主要参数的含义

本文详细解释了PyTorch中的DataLoader类,介绍了其在数据集加载中的作用,以及trainset和testset参数的设置。重点讨论了prefetch_factor、batch_size和pin_memory对内存占用的影响。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

定义:

DataLoader类是一个用于从数据集(dataset)中加载数据,并以迭代器(iterator)的形式返回数据样本(data samples)的工具¹²。您给出的两个字典(dictionary)分别是训练集(train set)和测试集(test set)的数据加载参数,下面我会逐一解释它们的含义和默认值:

 

举例演示:

假设prefetch_factor=2,batch_size=4,persistent_workers=True,你的回答里面的每个worker每次获取多少数据呢?

假设您有一个名为train_dataset的数据集,它包含了100个数据样本,每个样本由一个特征向量和一个标签组成。您可以使用以下代码创建一个使用prefetch_factor=2,batch_size=4,persistent_workers=True的DataLoader对象:

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, num_workers=8, prefetch_factor=2, persistent_workers=True)

 

 小提示:

pin_memory将数据样本保存在锁页内存,这里的内存指的是cpu的内存还是gpu的内存?

pin_memory参数是一个布尔值,表示是否将数据样本保存在锁页内存(pinned memory)中。这里的内存指的是CPU的内存,而不是GPU的内存12

 降低内存占用,推到的时候用到:

testloader_params = {
          'batch_size': 1,
          'shuffle': False,
          'num_workers': 0,
          'pin_memory': True,
          # 'prefetch_factor': 4,
          # 'persistent_workers': True,
          'drop_last': False
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

子燕若水

吹个大气球

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值