Using PyTorch's IterableDataset
Background
When a dataset is too large to load into memory all at once, PyTorch's regular map-style Dataset is no longer a good fit; in that case you need IterableDataset.
Basic usage
You only need to implement __init__(), __iter__(), and __len__(). A template looks like this:
from torch.utils.data import IterableDataset, DataLoader

class MyIterableDataset(IterableDataset):
    def __init__(self):
        # initialization code goes here
        pass

    def __iter__(self):
        # return an iterator over the samples
        pass

    def __len__(self):
        # return the number of samples
        pass

mydataset = MyIterableDataset()  # an iterable dataset object
mydataloader = DataLoader(mydataset, shuffle=False, batch_size=batch_size, num_workers=num_workers)  # shuffle must be False
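As a minimal, runnable sketch (RangeDataset is an illustrative name, not part of the original), an IterableDataset whose __iter__() returns a plain iterator can be fed to a DataLoader like this:

from torch.utils.data import IterableDataset, DataLoader

class RangeDataset(IterableDataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        # any iterator works here; a range iterator keeps the example tiny
        return iter(range(self.n))

    def __len__(self):
        return self.n

loader = DataLoader(RangeDataset(10), shuffle=False, batch_size=4)
for batch in loader:
    print(batch)  # tensor([0, 1, 2, 3]), then tensor([4, 5, 6, 7]), then tensor([8, 9])

Because an IterableDataset supports no random access, the DataLoader cannot reshuffle it, which is why shuffle has to stay False.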
An example
Reading the CNN_Dailymail summarization dataset (a tab-separated file with id, article, and summary columns):
import math
import torch
from torch.utils.data import IterableDataset, DataLoader

class SummaryDataset(IterableDataset):
    def __init__(self, file_path: str):
        super().__init__()
        self.file_path = file_path
        self.info = self._get_file_info(file_path)
        self.start = self.info['start']
        self.end = self.info['end']

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single worker: read the whole file
            iter_start = self.start
            iter_end = self.end
        else:  # multiple workers: each worker reads its own contiguous slice
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        sample_iterator = self._sample_generator(iter_start, iter_end)
        return sample_iterator

    def __len__(self):
        return self.end - self.start

    def _get_file_info(self, file_path):
        info = {
            "start": 1,  # reading starts at line 1, i.e. line 0 is skipped
            "end": 0,
            "id_colum": 0,
            "article_colum": 1,
            "summary_colum": 2
        }
        with open(file_path, 'r') as fin:
            for _ in fin:
                info['end'] += 1
        return info

    def _sample_generator(self, start, end):
        id_c, art_c, sum_c = self.info['id_colum'], self.info['article_colum'], self.info['summary_colum']
        with open(self.file_path, 'r') as fin:
            for i, line in enumerate(fin):
                if i < start:
                    continue
                if i >= end:
                    return
                items = line.strip().split('\t')
                sample = {"id": items[id_c], "article": items[art_c], "summary": items[sum_c]}
                yield sample

train_dataset = SummaryDataset(args.train_dataset)
train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=args.batch_size, num_workers=args.num_workers)
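To sanity-check the pipeline you can iterate over the DataLoader directly; this sketch assumes args has been populated as above:

for batch in train_dataloader:
    # with dict samples of strings, the default collate_fn returns a dict
    # whose values are lists of length batch_size
    print(batch["id"])
    print(len(batch["article"]))
    break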
If you combine this with DistributedDataParallel for multi-process distributed training, where each process loads a different shard of the data, avoid spawning additional DataLoader workers via num_workers inside each training subprocess; nesting processes this way is error-prone. The following pattern is recommended instead:
class SummaryDataset(IterableDataset):
    def __init__(self, file_path: str, rank, world_size):
        super().__init__()
        self.file_path = file_path
        self.info = self._get_file_info(file_path)
        self.start = self.info['start']
        self.end = self.info['end']
        self.rank = rank
        self.world_size = world_size
        # shard the file by rank: each process reads its own contiguous slice
        self.per_worker = int(math.floor((self.end - self.start) / float(self.world_size)))
        self.iter_start = self.start + self.rank * self.per_worker
        self.iter_end = min(self.iter_start + self.per_worker, self.end)

    def __iter__(self):
        sample_iterator = self._sample_generator(self.iter_start, self.iter_end)
        return sample_iterator

    def __len__(self):
        return self.iter_end - self.iter_start

    def _get_file_info(self, file_path):
        info = {
            "start": 1,  # reading starts at line 1, i.e. line 0 is skipped
            "end": 0,
            "id_colum": 0,
            "article_colum": 1,
            "summary_colum": 2
        }
        with open(file_path, 'r') as fin:
            for _ in fin:
                info['end'] += 1
        return info

    def _sample_generator(self, start, end):
        id_c, art_c, sum_c = self.info['id_colum'], self.info['article_colum'], self.info['summary_colum']
        with open(self.file_path, 'r') as fin:
            for i, line in enumerate(fin):
                if i < start:
                    continue
                if i >= end:
                    return
                items = line.strip().split('\t')
                sample = {"id": items[id_c], "article": items[art_c], "summary": items[sum_c]}
                yield sample

def train_worker(rank, args):
    # each distributed training subprocess builds its own shard
    ...
    train_dataset = SummaryDataset(args.train_dataset, rank, args.world_size)
    train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=args.batch_size, num_workers=0)
    ...
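The training subprocesses themselves can be launched with torch.multiprocessing.spawn. A minimal sketch, assuming args.world_size is defined and that train_worker initializes the process group (e.g. via torch.distributed.init_process_group) before wrapping the model in DistributedDataParallel:

import torch.multiprocessing as mp

if __name__ == "__main__":
    # spawn one training process per rank; each is called as train_worker(rank, args)
    mp.spawn(train_worker, args=(args,), nprocs=args.world_size, join=True)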