Pytorch IterableDataset的使用
背景
当数据量特别大,无法一次性load进内存时,Pytorch里的Dataset就无法胜任了,此时需要使用IterableDataset.
基本用法
只需要实现__init__()
、__iter__()
和__len__()
,模版如下:
from torch.utils.data import IterableDataset, DataLoader
class MyIterableDataset(IterableDataset):
def __init__(self):
# 实现初始化代码
pass
def __iter__(self):
# 返回一个数据的迭代器
pass
def __len__(self):
# 返回数据长度
pass
mydataset = MyIterableDataset() # 可迭代对象
mydataloader = DataLoader(mydataset, shuffle=False, batch_size=batch_size, num_workers=num_workers) # shuffle必须要是False
一个例子
读取CNN_Dailymail摘要数据集
class SummaryDataset(IterableDataset):
def __init__(self,
file_path: str
):
super(SummaryDataset).__init__()
self.file_path = file_path
self.info = self._get_file_info(file_path)
self.start = self.info['start']
self.end = self.info['end']
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single worker
iter_start = self.start
iter_end = self.end
else