Pytorch深度学习-DataLoader的使用(小土堆)

  1. API
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')
  1. 注释解析
1. dataset (Dataset) – 要从中加载数据的数据集。
2. batch_size (int, 可选) – 每批要加载的样本数 (默认值: `1` )。
3. shuffle (bool, 可选) – 设置为 `True` 在每次抓取数据时重新洗牌数据(默认值:`False`)。 
4. sampler(Sampler 或 Iterable,可选) – 定义从数据集中提取样本的策略。可以是任何 `Iterable` 与 `__len__` 实现。如果指定, `shuffle` 则不得指定。
5. batch_sampler (Sampler or Iterable, 可选) – 像 `sampler` ,但一次返回一批索引。与 `batch_size` 、 `shuffle` 、 `sampler` 和 `drop_last` 互斥。
6. num_workers (int, 可选) – 用于数据加载的子进程数。 `0` 表示数据将在主进程中加载。(默认值: `0` )
7. collate_fn (Callable, 可选) – 合并样本列表以形成一小批 Tensor。在使用地图样式数据集的批量加载时使用。
8. pin_memory (bool, 可选) - 如果 `True` ,数据加载器会在返回张量之前将张量复制到设备/CUDA 固定内存中。如果数据元素是自定义类型,或者 `collate_fn` 返回的批处理是自定义类型,请参阅下面的示例。
9. drop_last (bool, 可选) – `True` 如果数据集大小不能被批大小整除,即在倒数第二次抓取时,剩下的数据不能满足步长,则删除最后一个未完成的批处理。如果 `False` 数据集的大小不能被批处理大小整除,即在倒数第二次抓取时,剩下的数据不能满足步长,则直接抓取剩下的所有数据。(默认值: `False` )
10. timeout (numeric, 可选) – 如果为正数,则为从工作线程收集批处理的超时值。应始终为非负数。(默认值: `0` )
11. worker_init_fn (Callable, 可选) – 如果不是 `None` ,则在种子设定之后和数据加载之前,将在每个工作线程子进程上调用此操作器,并将工作线程 ID(int in `[0, num_workers - 1]` )作为输入。(默认值: `None` )
12. multiprocessing_context (str 或 multiprocessing.context.BaseContext,可选) - 如果 `None` ,将使用操作系统的默认多处理上下文。(默认值: `None` )
13. generator(torch.Generator, 可选) – 如果不是 `None` ,RandomSampler 将使用此 RNG 生成随机索引和 `base_seed` 多处理以生成 worker。(默认值: `None` )
14. prefetch_factor (int, 可选, keyword-only arg) - 每个工作线程预先加载的批次数。 `2` 意味着在所有工作线程中总共预取 2 * num_workers 批次。(默认值取决于num_workers的设置值。如果值 num_workers=0,则默认值为 `None` 。否则,如果 default 的 `num_workers > 0` 值为 `2` )。
15. persistent_workers (bool, 可选) - 如果 `True` ,数据加载程序不会在数据集被使用一次后关闭工作进程。这允许保持工作线程数据集实例处于活动状态。(默认值: `False` )
16. pin_memory_device (str, 可选) – 设备to `pin_memory` if `pin_memory` 为 `True` 。
  1. 代码示例
'''
测试数据为:
torch.Size([4, 3, 32, 32])
tensor([4, 1, 6, 7])
四个数据,三个通道,大小为32*32
target分别属于4,1,6,7
'''
import torchvision  
from torch.utils.data import DataLoader  
  
#准备测试数据集  
test_data = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)  
#实例化DataLoader  
#数据集为test_data,  
# 步长为4  
# 每次抓取都打乱  
# 在主进程进行运行  
# 最后一次数据不足时,抓取全部  
test_loader = DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)  
  
#测试数据集中的第一涨图片及target  
img,target = test_data[0]  
print(img.shape)  
print(target)  
  
for data in test_loader:  
    imgs,targets = data  
    print(imgs.shape)  
    print(targets)
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值