Pytorch Dataloader 模块源码分析(二):Sampler / Fetcher 组件及 Dataloader 核心代码

本文深入剖析Pytorch Dataloader的内部组件,包括Sampler(SequentialSampler, RandomSampler, BatchSampler)和Fetcher的工作原理。Sampler负责生成访问Dataset的index,Fetcher对Dataset做封装并转换为Pytorch Tensor。Dataloader的单线程和多线程场景中,Fetcher的作用是减少I/O瓶颈,提高数据加载效率。通过理解和优化这些组件,能提升Pytorch模型训练的效率。" 51906253,1841921,Git 进阶:掌握Submodule使用,"['Git', '版本管理', 'Submodule', '团队协作', '代码管理']

Dataloader 组件

Sampler 类

在看 Sampler 的具体实现之前,我们先看看 Dataloader 在什么时候产生 Sampler 对象:

class DataLoader(object):
    def __init__(self, ...):
        ...
        if sampler is None:  
            ...
             # 如果指定shuffle就使用随机采样,否则使用顺序采样
                if shuffle: 
                    sampler = RandomSampler(dataset, generator=generator)
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # 如果指定了batch_size又没有指定自定义的batch_sampler,就开启自动批采样
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
        ...

我们可以看到 Sampler 对象的主要职责就是生成用于访问 Dataset 的 index。其中 Sampler 的子类如下:

  • SequentialSampler 顺序采样
  • RandomSampler 随机采样
  • BatchSampler 批采样

实际上还有其他的采样方法,但是因为使用的不多,本文主要讲解上述的三种 Sampler。上述提到的几种采样类都是 Sampler 的子类,Sampler 中的__iter__方法定义为 raise NotImplementedError:

class Sampler(Generic[T_co]):
    def __init__(self, data_source: Optional[Sized]) -> None:
        pass
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

SequentialSampler

SequentialSampler 实现:

class SequentialSampler(Sampler[int]):
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
    	# 创建一个迭代器
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

这里主要关注__Iter__方法,实际上返回的 index 就是 range(len(self.data_source)) 顺序递增的结果:len(data_source) 实际上就是 Dataset 返回的 samples 的长度。创建迭代器之后,当对这个迭代器调用__next__方法,就会返回 0, 1, 2, 3, 4, … 顺序递增的 index。

RandomSampler

RandomSampler 实现:

class RandomSampler(Sampler[int]):
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator    
		...
    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator
		# replacement 表示是否可以生成重复 index
        if self.replacement:
        	# num_samples 表示一次性采样的数据量
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            for _ in range(self.num_samples // n):
                yield from torch.randperm(n, generator=generator)</
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值