通常与分布式DDP配合使用,用作dataloader的sampler,顾名思义,数据采集器决定dataloder在当前batch选择哪些数据,DistributedSampler就用于在分布式训练中对数据集进行划分和采样
数据集被认为是恒定大小的,并且它的任何实例总是以相同的顺序返回相同的元素。
参数解读
Args:
dataset: 用于采样的数据集。
num_replicas(int,可选): 参与分布式训练的进程的数量。默认情况下,:attr:`world_size`会从当前的分布式组中获取。
rank (int, optional): 当前进程在 :attr:`num_replicas`中的排名。默认情况下,:attr:`rank`会从当前的分布式组中检索。
shuffle (bool, optional): 如果`True`(默认),采样器将洗牌索引。
seed (int, optional): 如果 :attr:`shuffle=True`,随机的种子用于洗牌采样器。这个数字在分布式组的所有进程中应该是相同的。默认值: `0```。
drop_last (bool, optional): 如果`True`,那么采样器将放弃数据的尾部,以使其在副本的数量上均匀地分配。如果`False`,采样器将增加额外的索引,使数据在副本中平均分配。默认值: `False'.
示范
>>> # xdoctest: +SKIP
>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
... sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
... if is_distributed:
... sampler.set_epoch(epoch)
... train(loader)
在分布式模式下,在每个周期开始时调用`set_epoch` 方法
在创建 :class:`DataLoader` 迭代器之前,在每个纪元的开始调用`set_epoch`方法是必要的,以使洗牌在多个epoch中正常工作。否则相同的排序将总是被使用。
且每个进程的seed应该一样,以保证每个进程的smpler会配合遍历这个数据集
init
接下来看代码来讲解
def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
rank: Optional[int] = None, shuffle: bool = True,
seed: int = 0, drop_last: bool = False) -> None:
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
if rank >= num_replicas or rank < 0:
raise ValueError(
"Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1))
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.seed = seed
这段代码定义了一个类,该类用于生成数据集的采样器(Sampler)。以下是对代码进行详细解释:
构造函数 __init__
接受多个参数:dataset
(数据集),num_replicas
(副本数量,默认为 None),rank
(当前副本的编号,默认为 None),shuffle
(是否在每个 epoch 之前打乱数据,默认为 True),seed
(随机种子,默认为 0),drop_last
(是否丢弃最后一部分数据,默认为 False)。详见上面的参数讲解
首先,通过判断 dist
包是否可用来确定是否使用分布式训练。如果不可用,则抛出运行时错误。
接下来,通过检查 num_replicas
和 rank
的值来确定当前副本的数量和编号。如果未指定 num_replicas
或者 rank
,且分布式包可用,则使用分布式包提供的函数获取当前副本的数量和编号。
然后,检查 rank
是否在合理范围内,即[0, num_replicas - 1]。如果超出范围,则抛出值错误异常。
接下来,将传入的数据集、副本数量、副本编号、当前 epoch 数等赋值给相应的成员变量。
根据是否设置了 drop_last
,以及数据集的长度是否能够被副本数量整除,来确定总共需要采样的数量。如果设置了 drop_last
且数据集的长度不能被副本数量整除,则采样的数量为 (len(dataset) - num_replicas) / num_replicas
,并使用 math.ceil
函数向上取整。否则,采样的数量为 len(dataset) / num_replicas
。
计算总共的样本数量为采样数量乘以副本数量。
最后,将是否需要打乱数据、随机种子等赋值给相应的成员变量。
len与setepoch方法
def __len__(self) -> int:
return self.num_samples
def set_epoch(self, epoch: int) -> None:
r"""
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
其他sampler的__len__()只负责返回数据集dataset包含的数据个数;
这里的sampler返回的是本进程划负责的数据集长度
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) 见上面的init方法中对num_samples的定义
iter方法
def __iter__(self) -> Iterator[T_co]:
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
对于所有的采样器来说,都需要继承Sampler类,必须实现的方法为__iter__(),也就是定义迭代器行为,返回可迭代对象。除此之外,Sampler类并没有定义任何其它的方法。
__iter__()方法负责返回一个可迭代对象,实现了对数据集的采样和划分。
首先,根据是否设置了 shuffle
(是否打乱数据)来判断是否需要对数据进行打乱。如果需要打乱数据,则通过创建一个随机数生成器 torch.Generator()
,并基于当前 epoch 和随机种子 seed
对其进行设置。然后使用生成器生成一个长度为数据集长度的随机排列索引 indices
。如果不需要打乱数据,则直接将索引列表初始化为数据集的索引列表。
接下来,根据是否设置了 drop_last
(是否丢弃最后一部分数据)来决定是否需要增加或删除样本以使得总样本数量能够被副本数量整除。
如果没有设置 drop_last
,则需要添加额外的样本以使总样本数量能够被副本数量整除。首先计算需要添加的填充样本数量 padding_size
,即总样本数量与当前索引列表长度之差。然后根据两种情况进行判断:如果 padding_size
小于等于索引列表的长度,说明可以从索引列表中复制 padding_size
个样本来填充,所以将列表的前 padding_size
个元素复制并追加到索引列表的末尾;如果 padding_size
大于索引列表的长度,说明无法仅从索引列表中复制足够数量的样本来填充,所以将索引列表重复拼接,并且拼接的次数是通过 math.ceil(padding_size / len(indices))
来计算的(向上取整),然后再从中截取前 padding_size
个元素作为填充样本。这样做可以确保总样本数量能够被副本数量整除。
如果设置了 drop_last
,则需要删除多余的样本使总样本数量能够被副本数量整除,即将索引列表保留前 total_size
个元素。
最后,根据当前副本编号 rank
、总样本数量 total_size
和副本数量 num_replicas
对索引列表进行划分和子采样。使用切片操作 indices[self.rank:self.total_size:self.num_replicas]
,从索引列表中选择属于当前副本的样本索引。确保最终子采样后的样本数量等于 num_samples
。
最后,返回一个迭代器对象,该迭代器对应着采样后的样本索引列表,并可以通过 for 循环迭代获取每个样本的索引。
测试
虚拟一个数据集
dataset = list([6, 36, 9, 26, 17])
手动创建两个分布式数据采集器
sampler0 = DistributedSampler(dataset,num_replicas=2,rank=0)
sampler1 = DistributedSampler(dataset,num_replicas=2,rank=1)
打印出来
for index in sampler0:
print("index: {}, data: {}".format(str(index), str(dataset[index])))
结果
sampler0-----index: 4, data: 17 index: 1, data: 36 index: 2, data: 9
sampler1-----index: 0, data: 6 index: 3, data: 26 index: 4, data: 17
拆解开看看
g = torch.Generator()
g.manual_seed(0)
lenth = len(dataset)#5
indices = torch.randperm(lenth, generator=g).tolist()#[4, 0, 1, 3, 2]
结果:
sampler1.total_size#6
sampler1.num_replicas#2
#indices = indices[self.rank:self.total_size:self.num_replicas]
indices = indices[0:6:2]或者indices = indices[1:6:2] 不同sampler的rank不一样
indices#[4, 1, 2]
和前面一致
# SubsetRandomSampler
和这个很像
class SubsetRandomSampler(Sampler[int]):
r"""Samples elements randomly from a given list of indices, without replacement.
Args:
indices (sequence): a sequence of indices
generator (Generator): Generator used in sampling.
"""
indices: Sequence[int]
def __init__(self, indices: Sequence[int], generator=None) -> None:
self.indices = indices
self.generator = generator
def __iter__(self) -> Iterator[int]:
for i in torch.randperm(len(self.indices), generator=self.generator):
yield self.indices[i]
def __len__(self) -> int:
return len(self.indices)