Pytorch Dataloader之batch_sampler

batch_sampler (Sampler or Iterable, 可选)

sampler是返回Dataset所有数据的索引,batch_sampler是返回一个mini-batch数据的索引。与batch_size,shuffle,sampler和drop_last参数互斥。如果自定了BatchSampler,Dataloader则采用你定义的BatchSampler。如果不传或传入参数batch_sampler为None,Dataloader也实现了一个默认的BatchSampler

class BatchSampler(Sampler[List[int]]):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        batch = []
        for idx in self.sampler:    # 遍历Sampler的返回的索引序列
            batch.append(idx)    # 将sampler返回的索引,添加到batch列表中
            if len(batch) == self.batch_size:   # batch列表长度等于batch_size则生成一个batch的索引数据
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # 是否drop_last
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]

  主要是关注两个方法:

  1. __iter__()方法,通过遍历Sampler的返回的索引序列(line34, 长度=数据集长度),将其打包成一个个mini-batch,然后以迭代器的形式返回。batch_sampler每次返回的是一个mini-batch数据的索引(line36, 长度=batch_size)。

  2. __len__()方法,根据drop_last参数,返回总共有多少个mini-batch。


嗨,欢迎大家关注我的公众号《CV之路》,一起讨论问题,一起学习进步~。也欢迎大家关注我的GitHub仓库,我出的所有博文教程都是无偿分享的,只求个关注与Star~,多谢大家支持!

GitHub - gy-7/CV-Road (后续教程相关所有代码都会维护到此仓库)

 

  • 5
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

gy-7

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值