pytorch dataloader食用指南,附上个人使用代码

Dataloader可是个好东西,wrap一下就可以当作python generator使用,快捷、省内存,还能配合tqdm、trange等进度条,达到很好的观测效果。

下面附上本人使用的代码,可以快速生成一个dataloader:

def simple_data_loader(data: list, batch_size: int, random: bool):
    ''' create a naive data loader '''
    dataset = myDataset(data)
    sampler = RandomSampler(dataset) if random else SequentialSampler(dataset)
    data_loader = DataLoader(dataset=dataset, batch_size=batch_size, sampler=sampler)

    return data_loader

由于dataloader只是一个可迭代对象,无法使用.next()方法,所以我还利用torch0.3版本的历史代码写了一个迭代器版本的loader:

from utils.dataloader_iter import DataLoaderIter

class myDataIter(object):
    '''  a torch data_iterator  '''

    def __init__(self, data, batch_size, random=True):
        self.data = data
        self.random = random
        self._data_iter = None
        self._batch_size = batch_size
        self._iteration = 0
        self._reset = False  # one epoch

    def get_iteration(self):
        return self._iteration

    def if_reset(self):
        return self._reset

    def _build(self):
        ''' create a new DataLoaderIter object '''
        # dataset = myDataset(self.data)
        # sampler = RandomSampler(dataset) if self.random else SequentialSampler(dataset)
        # data_loader = DataLoader(dataset=dataset, batch_size=self._batch_size, sampler=sampler)
        data_loader = simple_data_loader(self.data, self._batch_size, self.random)
        self._data_iter = DataLoaderIter(data_loader)

    def next(self):
        ''' get next batch data
        if the data has been taken out then initialize a new data_iterator
        :return: a batch of data
        '''
        if self._data_iter is None:
            warnings.warn("create data_loader_iter firstly")
            self._build()
        try:
            batch = self._data_iter.next()
            self._iteration += 1

            return batch
        except StopIteration:
            self._build()
            self._iteration = 1  # reset and return the 1st batch
            self._reset = True

            batch = self._data_iter.next()

            return batch

由于这个需要用Dataiter这个class,而它早就在torch0.4之前就被废弃了,可以参考我的另一篇blog:torch.utils.data.dataloader.DataLoaderIter 无法导入问题
把Dataiter的问题解决。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值