Dataloader可是个好东西,wrap一下就可以当作python generator使用,快捷、省内存,还能配合tqdm、trange等进度条,达到很好的观测效果。
下面附上本人使用的代码,可以快速生成一个dataloader:
def simple_data_loader(data: list, batch_size: int, random: bool):
    """Build a plain DataLoader over *data*.

    :param data: list of samples, wrapped by ``myDataset``
    :param batch_size: number of samples per batch
    :param random: when True, shuffle every epoch via ``RandomSampler``;
        otherwise iterate in fixed order via ``SequentialSampler``
    :return: a ``torch.utils.data.DataLoader`` yielding batches of *data*
    """
    wrapped = myDataset(data)
    if random:
        order = RandomSampler(wrapped)
    else:
        order = SequentialSampler(wrapped)
    return DataLoader(dataset=wrapped, batch_size=batch_size, sampler=order)
由于dataloader只是一个可迭代对象(iterable)而不是迭代器(iterator),不能直接对它调用.next()方法取下一个batch,所以我还利用torch 0.3版本的历史代码写了一个迭代器版本的loader:
from utils.dataloader_iter import DataLoaderIter
class myDataIter(object):
    """An endlessly re-iterable batch iterator built on top of a DataLoader.

    When the underlying loader is exhausted, a fresh one is built
    transparently, so ``next()`` always yields a batch (for non-empty data).
    Implements the Python 3 iterator protocol (``__iter__``/``__next__``)
    while keeping the legacy ``.next()`` method for existing callers.
    """

    def __init__(self, data, batch_size, random=True):
        self.data = data
        self.random = random
        self._data_iter = None          # lazily-built DataLoaderIter
        self._batch_size = batch_size
        self._iteration = 0             # batches served in the current epoch
        self._reset = False             # True once an epoch boundary was crossed

    def get_iteration(self):
        """Return the number of batches served in the current epoch."""
        return self._iteration

    def if_reset(self):
        """Return True once the iterator has wrapped around at least one epoch."""
        return self._reset

    def _build(self):
        """Create a fresh DataLoaderIter object over the data."""
        data_loader = simple_data_loader(self.data, self._batch_size, self.random)
        self._data_iter = DataLoaderIter(data_loader)

    def next(self):
        """Get the next batch of data.

        If the data has been fully taken out, a new data iterator is
        initialized and the first batch of the new epoch is returned.

        :return: a batch of data
        :raises StopIteration: only if the underlying dataset is empty
        """
        if self._data_iter is None:
            warnings.warn("create data_loader_iter firstly")
            self._build()
        try:
            batch = self._data_iter.next()
            self._iteration += 1
            return batch
        except StopIteration:
            # Exhausted: rebuild the loader, start a new epoch, and hand
            # back that epoch's first batch.
            self._build()
            self._iteration = 1  # reset and return the 1st batch
            self._reset = True
            return self._data_iter.next()

    # Python 3 iterator protocol, so ``next(it)`` and ``for batch in it`` work.
    __next__ = next

    def __iter__(self):
        return self
由于这段代码依赖DataLoaderIter这个class,而它从torch 0.4起就被废弃(改为私有类)、无法直接导入了,可以参考我的另一篇blog《torch.utils.data.dataloader.DataLoaderIter 无法导入问题》来解决DataLoaderIter的导入问题。