batch_sampler (Sampler or Iterable, optional)
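`sampler` yields the index of each sample in the Dataset, one at a time. `batch_sampler` instead yields the indices of one whole mini-batch at a time. `batch_sampler` is mutually exclusive with the `batch_size`, `shuffle`, `sampler`, and `drop_last` arguments. If you supply your own BatchSampler, DataLoader uses it directly. If `batch_sampler` is not passed, or is passed as None, DataLoader builds a default BatchSampler itself (provided `batch_size` is not None) by wrapping `sampler` with `batch_size` and `drop_last`. That default class is implemented as follows: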
```python
from typing import Iterator, List

from torch.utils.data import Sampler


class BatchSampler(Sampler[List[int]]):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        batch = []
        for idx in self.sampler:  # iterate over the index sequence produced by the sampler
            batch.append(idx)  # add each index returned by the sampler to the current batch
            if len(batch) == self.batch_size:  # batch is full: yield one mini-batch of indices
                yield batch
                batch = []
        # leftover indices form a shorter final batch unless drop_last is set
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        if self.drop_last:
            # drop_last=True: the incomplete final batch is discarded (floor division)
            return len(self.sampler) // self.batch_size
        else:
            # drop_last=False: round up so a partial final batch is counted (ceiling division)
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]
```
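The example below tests the claim that DataLoader builds this default BatchSampler when `batch_sampler` is None. It is a minimal sketch that assumes PyTorch is installed. The toy dataset `list(range(10))` and the variable names are illustrative only.

The sketch does three things:
1. It compares the batches produced by `batch_size=3` with those from an explicitly passed `BatchSampler`.
2. It checks the type of the loader's `batch_sampler` attribute.
3. It captures the `ValueError` that DataLoader raises when `batch_sampler` is combined with `batch_size`.

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

# Illustrative toy map-style dataset: a plain list supports __getitem__ and __len__.
dataset = list(range(10))

# 1) batch_sampler left as None: DataLoader wraps its default sampler
#    (a SequentialSampler, since shuffling is off) in a BatchSampler internally.
implicit_loader = DataLoader(dataset, batch_size=3, drop_last=False)

# 2) The same batching expressed explicitly through batch_sampler.
explicit_loader = DataLoader(
    dataset,
    batch_sampler=BatchSampler(SequentialSampler(dataset), batch_size=3, drop_last=False),
)

# The default collate_fn turns each list of ints into a tensor; convert back to lists.
implicit_batches = [batch.tolist() for batch in implicit_loader]
explicit_batches = [batch.tolist() for batch in explicit_loader]
print(implicit_batches)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(type(implicit_loader.batch_sampler).__name__)  # BatchSampler

# 3) batch_sampler is mutually exclusive with batch_size/shuffle/sampler/drop_last,
#    so combining it with batch_size=3 is rejected.
try:
    DataLoader(dataset, batch_size=3,
               batch_sampler=BatchSampler(SequentialSampler(dataset), 3, False))
    conflict_error = None
except ValueError as err:
    conflict_error = err
print("ValueError:", conflict_error)
```

Both loaders should yield identical batches, and `len(implicit_loader)` equals `len(implicit_loader.batch_sampler)`, i.e. 4 for 10 samples with `batch_size=3`.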
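Of BatchSampler's methods, two deserve particular attention:
`__iter__()` iterates over the index sequence returned by the sampler (the `for idx in self.sampler` loop, whose length equals the dataset length). It packs those indices into mini-batches and yields them one at a time as an iterator. Each item the batch_sampler yields is therefore the list of indices for one mini-batch (the `yield batch` statement). Each list has length `batch_size`, except that the last one may be shorter when `drop_last=False`.
`__len__()` returns the total number of mini-batches, which depends on `drop_last`. With `drop_last=True` it is `len(sampler) // batch_size` (rounded down). Otherwise it is `(len(sampler) + batch_size - 1) // batch_size`, i.e. `len(sampler) / batch_size` rounded up.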
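You can check that the two methods agree without PyTorch. The following is a dependency-free restatement of the same logic, a sketch rather than the library code. `iter_batches` and `count_batches` are hypothetical helper names, not part of any library. The final loop verifies, over a range of dataset sizes and batch sizes, that the count `__len__` would report equals the number of batches `__iter__` actually yields.

```python
from typing import Iterable, Iterator, List


def iter_batches(indices: Iterable[int], batch_size: int, drop_last: bool) -> Iterator[List[int]]:
    """Hypothetical helper mirroring BatchSampler.__iter__: chunk indices into lists."""
    batch: List[int] = []
    for idx in indices:
        batch.append(idx)
        if len(batch) == batch_size:  # full batch: emit it and start a new one
            yield batch
            batch = []
    if batch and not drop_last:  # partial final batch is kept only if drop_last is False
        yield batch


def count_batches(n: int, batch_size: int, drop_last: bool) -> int:
    """Hypothetical helper mirroring BatchSampler.__len__ for n indices."""
    if drop_last:
        return n // batch_size  # floor: the incomplete tail is discarded
    return (n + batch_size - 1) // batch_size  # ceiling: the tail counts as one batch


# Worked example with 10 indices and batch_size=3.
print(list(iter_batches(range(10), 3, drop_last=False)))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(count_batches(10, 3, drop_last=False), count_batches(10, 3, drop_last=True))  # 4 3

# The length formula matches the number of batches actually produced.
for n in range(50):
    for bs in range(1, 12):
        for drop in (False, True):
            produced = sum(1 for _ in iter_batches(range(n), bs, drop))
            assert produced == count_batches(n, bs, drop), (n, bs, drop)
```

The expression `(n + batch_size - 1) // batch_size` computes the ceiling with integer arithmetic only, so it avoids the floating-point division that `math.ceil(n / batch_size)` would need.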
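Hi, you're welcome to follow my WeChat official account 《CV之路》 (CV Road), where we can discuss problems and learn together. You're also welcome to follow my GitHub repository. All my blog tutorials are shared for free; all I ask is a follow and a Star. Thanks for your support!
GitHub - gy-7/CV-Road (all code for future tutorials will be maintained in this repository)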