问题
最近用pytorch做实验时,遇到加载大量数据的问题。实验数据大小在400Gb,而本身机器的memory只有256Gb,显然无法将数据一次全部load到memory。
解决方法
首先自定义一个MyDataset继承torch.utils.data.Dataset,然后将MyDataset的对象feed in torch.utils.data.DataLoader()即可。MyDataset在__init__中声明一个文件对象,然后在__getitem__中缓慢读取数据,这样就不会一次把所有数据加载到内存中了。训练数据存放在train.txt中,每一行是一条数据记录。
import torch.utils.data as Data
from tqdm import tqdm
class MyDataset(Data.Dataset):
def __init__(self,filepath):
number = 0
with open(filepath,"r") as f:
# 获得训练数据的总行数
for _ in tqdm(f,desc="load training dataset"):
number+=1
self.number = number
self.fopen = open(filepath,'r')
def __len__(self):
return self.number
def __getitem__(self,index):
line = self.fopen.__next__()
# 自定义transform()对训练数据进行预处理
data = transform(line)
return data
train_dataset = MyDataset(filepath = "train.txt")
training_data = Data.DataLoader(dataset=train_dataset, batch_size=32,num_workers=1)
注意
- num_workers只能设置为1。因为MyDataset初始化时只有一个文件对象,在dataloader时num_workers=1只用一个线程去操作文件对象读取数据。如果num_workers>1, 会出错,多个线程同时操作同一个文件对象,得到的数据并不是你想要的。
- 每一个epoch结束以后,需要重新声明train_dataset和training_data。因为一个epoch结束以后,文件对象已经指向文件末尾,下一个epoch取数据时,什么也得不到。
- 因为这里__getitem__()只是顺序的从文件中取出一行,而与index无关,那么在DataLoader时,即使参数shuffle指定为True,得到的数据依然是顺序的,即该方法无法shuffle数据。