【PyTorch】使用DataLoader自定义数据集读取
为了方便之后使用PyTorch的distributed部署,加速训练,将数据读取的方式改为适配pytorch提供的Dataset和DataLoader的方式。这里记录一下修改的要点:
1. 涉及的import库:
import torch
from torch.utils.data import Dataset, DataLoader
2. 自定义一个Dataset类:
-
该类继承Dataset;
-
可以定义若干个数据预处理的函数,关键的两个函数是:
__len__()
和__getitem__()
; -
__getitem__()
实际是python支持的一个迭代器函数,编写时每次返回一个sample,不需要定义batch size,之后的DataLoader会自动帮忙读取数据组成batch的; -
举个栗子:
class MyDataset(Dataset): def __init__(self,data): self.data = data def __len__(self): return len(self.data) def __getitem__(self): return self.data def output(self): print('output')
3. 初始化Dataset和DataLoader类:
-
DataLoader的参数可参考:https://blog.csdn.net/zyq12345678/article/details/90268668
-
注意,如果在Dataset中每次返回的是自己定义的数据类型,或者是字典类型,有时要自己编写
collate_fn()
函数,告诉系统如何返回一个batch。 -
举个栗子:
dataset = MyDataset(data) dataloader = DataLoader( dataset, batch_size = 2, num_workers = 8, collate_fn = collate_fn, pin_memory = True ) # 返回数据结构较复杂,包括自定义数据类型或字典时 def collate_fn(batch): data = list(batch) return (data)
-
如果遇到类似报错:
TypeError: can't pickle _thread._local objects
请将DataLoader中的
num_workers
参数设置为0,关闭多线程。原因可能是无法自动多线程处理复杂的数据类型。
4. 访问Dataloader内的Dataset类函数
- 举个栗子:
for step, batch in enumerate(dataloader):
dataloader.dataset.output()