【PyTorch】使用DataLoader自定义数据集读取

【PyTorch】使用DataLoader自定义数据集读取

为了方便之后使用PyTorch的distributed部署,加速训练,将数据读取的方式改为适配pytorch提供的Dataset和DataLoader的方式。这里记录一下修改的要点:

1. 涉及的import库:

import torch
from torch.utils.data import Dataset, DataLoader

2. 自定义一个Dataset类:

  • 该类继承Dataset;

  • 可以定义若干个数据预处理的函数,关键的两个函数是:__len__()__getitem__();

  • __getitem__()实际是python支持的一个迭代器函数,编写时每次返回一个sample,不需要定义batch size,之后的DataLoader会自动帮忙读取数据组成batch的;

  • 举个栗子:

    class MyDataset(Dataset):
    	def __init__(self,data):
    		self.data = data
    	def __len__(self):
    		return len(self.data)
    	def __getitem__(self):
    		return self.data
    	def output(self):
    		print('output')
    

3. 初始化Dataset和DataLoader类:

  • DataLoader的参数可参考:https://blog.csdn.net/zyq12345678/article/details/90268668

  • 注意,如果在Dataset中每次返回的是自己定义的数据类型,或者是字典类型,有时要自己编写collate_fn()函数,告诉系统如何返回一个batch。

  • 举个栗子:

    dataset = MyDataset(data)
    dataloader = DataLoader(
        dataset,
        batch_size = 2,
        num_workers = 8,
        collate_fn = collate_fn,
        pin_memory = True
    )
    # 返回数据结构较复杂,包括自定义数据类型或字典时
    def collate_fn(batch):
        data = list(batch)
        return (data)
    
  • 如果遇到类似报错:

    TypeError: can't pickle _thread._local objects

    请将DataLoader中的num_workers参数设置为0,关闭多线程。原因可能是无法自动多线程处理复杂的数据类型。

4. 访问Dataloader内的Dataset类函数

  • 举个栗子:
for step, batch in enumerate(dataloader):
	dataloader.dataset.output()
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值