pytorch Dataloader Sampler参数深入理解

DataLoader函数

  • 参数与初始化

 def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):

其中几个常用的参数

  1. dataset 数据集,map-style and iterable-style 可以用index取值的对象、
  2. batch_size 大小
  3. shuffle 取batch是否随机取, 默认为False
  4. sampler 定义取batch的方法,是一个迭代器, 每次生成一个key 用于读取dataset中的值
  5. batch_sampler 也是一个迭代器, 每次生次一个batch_size的key
  6. num_workers 参与工作的线程数
  7. collate_fn 对取出的batch进行处理
  8. drop_last 对最后不足batchsize的数据的处理方法

下面看两段取自DataLoader中的__init__代码, 帮助我们理解几个常用参数之间的关系

	if sampler is None:  # give default samplers
	    if self._dataset_kind == _DatasetKind.Iterable:
	        # See NOTE [ Custom Samplers and IterableDataset ]
	        sampler = _InfiniteConstantSampler()
	    else:  # map-style
	        if shuffle:
	            sampler = RandomSampler(dataset)
	        else:
	            sampler = SequentialSampler(dataset)

可以看出, 当dataset类型是map style时, shuffle其实就是改变sampler的取值

  • shuffle为默认值 False时,sampler是SequentialSampler,就是按顺序取样,
  • shuffle为True时,sampler是RandomSampler, 就是按随机取样

所以当我们sampler有输入时,shuffle的值就没有意义,后面我们再看sampler的定义方法

再看一段初始化代码

    if batch_size is not None and batch_sampler is None:
        # auto_collation without custom batch_sampler
        batch_sampler = BatchSampler(sampler, batch_size, drop_last)
        
    self.sampler = sampler
    self.batch_sampler = batch_sampler

再看看,BatchSampler的生成过程

# 略去类的初始化
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

就是按batch_size从sampler中读取索引, 并形成生成器返回。

以上可以看出, batch_sampler和sampler, batch_size, drop_last之间的关系

  • 如果batch_sampler没有定义的话且batch_size有定义, 会根据sampler, batch_size, drop_last生成一个batch_sampler
  • 自带的注释中对batch_sampler有一句话: Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
  • 意思就是batch_sampler 与这些参数冲突 ,即 如果你定义了batch_sampler, 其他参数都不需要有

再看batch的生成过程

每个batch都是由迭代器产生的

# DataLoader中iter的部分
    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

# 再看调用的另一个类
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def __next__(self):
        index = self._next_index()  
        data = self._dataset_fetcher.fetch(index)  
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

对上面的代码进行一一解读, 初始化略过

  • 先对_next_index()一步步溯源

    def _next_index(self):
        return next(self._sampler_iter) 
    ///
	self._sampler_iter = iter(self._index_sampler)  
# 以上又用了一个迭代器生成索引 
 	self._index_sampler = loader._index_sampler
 	///
    def _index_sampler(self):
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler 

可以看出, _next_index其实就是batch_sample或是sampler用迭代器生成了一遍。
而sampler返回的就dataset中对应的索引值

  • 再看 _dataset_fetcher函数
 def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)

# 按map-style往下看
class _MapDatasetFetcher(_BaseDatasetFetcher):
    
	# 略过初始化
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
            # 关键
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

可以看到DataLoader中迭代器生成的data, 就是根据sampler或者batch_sampler生成的索引,从dataset中取的值, 然后经过collate_fn的处理。
这里还要提一下auto_collation参数:

    def _auto_collation(self):
        return self.batch_sampler is not None

其实就是判断batch_sampler 是否为None的情况, 而根据batch_sampler的定义, 只有,初始化参数batch_size和batch_sampler都为None时,才为False。这时, 从fetch函数可以看出,就是每次取一个值, _next_index()取的也是sampler, 此时相当与batch_size等于1。
由此,明白了整个大致过程,我们就可以对sampler进行定义,来获得我们想要的batch

  • sampler 参数的使用

sampler 是用来定义取batch方法的一个函数或者类,返回的是一个迭代器。

  • 我们可以看下自带的RandomSampler类中最重要的iter函数
    def __iter__(self):
        n = len(self.data_source)
        # dataset的长度, 按顺序索引
        if self.replacement:# 对应的replace参数
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        return iter(torch.randperm(n).tolist())        

可以看出,其实就是生成索引,然后随机的取值, 然后再迭代。
其实还有一些细节需要注意理解:

  • 比如__len__函数,包括DataLoader的len和sample的len, 两者区别, 这部分代码比较简单,可以自行阅读,其实参考着RandomSampler写也不会出现问题。
  • 比如,迭代器和生成器的使用, 以及区别

附上最近读的reid的代码中涉及sampler的部分

  • 关于dataset预处理和collate_fn的一些问题

在dataset预处理中,曾遇到这样一个问题:interrupted by signal 9: SIGKILL
经过查询才知道,是内存溢出。 后来经过查看以上链接中michuanhaohao的代码才发现问题, 预处理并不是对整个dataset同时进行预处理,然后传入DataLoader,而是把raw_dataset直接传入DataLoader, 当读取一个batch时,对batch进行处理,这样确实节省内存。
附上michuanhaohao代码:

class ImageDataset(Dataset):
    """Image Person ReID Dataset"""

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_path, pid, camid = self.dataset[index]
        img = read_image(img_path)

        if self.transform is not None:
            img = self.transform(img)

        return img, pid, camid, img_path

之前说过,DataLoader的参数dataset类型可以是列表,字典等, 可以用索引去读取值的类型。
此处,则是一个类, 类中定义了__getitem__函数, 使其能够用index去取值。
这个类初始化输入的dataset其实是图片地址和id参数,而当有index来访问时, 再去读取一个batch的图片,然后再对图片进行transform, 可以节省内存。

collate_fn 函数,可以从上面的fetch部分中看到, 也是对读取到的batch进行处理的一个对象,所以,预处理实际上也可以放在collate_fn中。

  • 46
    点赞
  • 73
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值