解读torch的 DistributedSampler

通常与分布式DDP配合使用,用作dataloader的sampler,顾名思义,数据采集器决定dataloder在当前batch选择哪些数据,DistributedSampler就用于在分布式训练中对数据集进行划分和采样

数据集被认为是恒定大小的,并且它的任何实例总是以相同的顺序返回相同的元素。

参数解读

Args:
        dataset: 用于采样的数据集。
        num_replicas(int,可选): 参与分布式训练的进程的数量。默认情况下,:attr:`world_size`会从当前的分布式组中获取。
        rank (int, optional): 当前进程在 :attr:`num_replicas`中的排名。默认情况下,:attr:`rank`会从当前的分布式组中检索。
        shuffle (bool, optional): 如果`True`(默认),采样器将洗牌索引。
        seed (int, optional): 如果 :attr:`shuffle=True`,随机的种子用于洗牌采样器。这个数字在分布式组的所有进程中应该是相同的。默认值: `0```。
        drop_last (bool, optional): 如果`True`,那么采样器将放弃数据的尾部,以使其在副本的数量上均匀地分配。如果`False`,采样器将增加额外的索引,使数据在副本中平均分配。默认值: `False'.

示范

        >>> # xdoctest: +SKIP
        >>> sampler = DistributedSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
        ...                     sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)

     在分布式模式下,在每个周期开始时调用`set_epoch` 方法
在创建 :class:`DataLoader` 迭代器之前,在每个纪元的开始调用`set_epoch`方法是必要的,以使洗牌在多个epoch中正常工作。否则相同的排序将总是被使用。

 且每个进程的seed应该一样,以保证每个进程的smpler会配合遍历这个数据集

init

接下来看代码来讲解

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

这段代码定义了一个类,该类用于生成数据集的采样器(Sampler)。以下是对代码进行详细解释:

构造函数 __init__ 接受多个参数:dataset(数据集),num_replicas(副本数量,默认为 None),rank(当前副本的编号,默认为 None),shuffle(是否在每个 epoch 之前打乱数据,默认为 True),seed(随机种子,默认为 0),drop_last(是否丢弃最后一部分数据,默认为 False)。详见上面的参数讲解

首先,通过判断 dist 包是否可用来确定是否使用分布式训练。如果不可用,则抛出运行时错误。

接下来,通过检查 num_replicasrank 的值来确定当前副本的数量和编号。如果未指定 num_replicas 或者 rank,且分布式包可用,则使用分布式包提供的函数获取当前副本的数量和编号。

然后,检查 rank 是否在合理范围内,即[0, num_replicas - 1]。如果超出范围,则抛出值错误异常。

接下来,将传入的数据集、副本数量、副本编号、当前 epoch 数等赋值给相应的成员变量。

根据是否设置了 drop_last,以及数据集的长度是否能够被副本数量整除,来确定总共需要采样的数量。如果设置了 drop_last 且数据集的长度不能被副本数量整除,则采样的数量为 (len(dataset) - num_replicas) / num_replicas,并使用 math.ceil 函数向上取整。否则,采样的数量为 len(dataset) / num_replicas

计算总共的样本数量为采样数量乘以副本数量。

最后,将是否需要打乱数据、随机种子等赋值给相应的成员变量。

len与setepoch方法


    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch

其他sampler的__len__()只负责返回数据集dataset包含的数据个数;

这里的sampler返回的是本进程划负责的数据集长度

self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) 见上面的init方法中对num_samples的定义

iter方法

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

对于所有的采样器来说,都需要继承Sampler类,必须实现的方法为__iter__(),也就是定义迭代器行为,返回可迭代对象。除此之外,Sampler类并没有定义任何其它的方法。

__iter__()方法负责返回一个可迭代对象,实现了对数据集的采样和划分。

首先,根据是否设置了 shuffle(是否打乱数据)来判断是否需要对数据进行打乱。如果需要打乱数据,则通过创建一个随机数生成器 torch.Generator(),并基于当前 epoch 和随机种子 seed 对其进行设置。然后使用生成器生成一个长度为数据集长度的随机排列索引 indices。如果不需要打乱数据,则直接将索引列表初始化为数据集的索引列表。

接下来,根据是否设置了 drop_last(是否丢弃最后一部分数据)来决定是否需要增加或删除样本以使得总样本数量能够被副本数量整除。

如果没有设置 drop_last,则需要添加额外的样本以使总样本数量能够被副本数量整除。首先计算需要添加的填充样本数量 padding_size,即总样本数量与当前索引列表长度之差。然后根据两种情况进行判断:如果 padding_size 小于等于索引列表的长度,说明可以从索引列表中复制 padding_size 个样本来填充,所以将列表的前 padding_size 个元素复制并追加到索引列表的末尾;如果 padding_size 大于索引列表的长度,说明无法仅从索引列表中复制足够数量的样本来填充,所以将索引列表重复拼接,并且拼接的次数是通过 math.ceil(padding_size / len(indices)) 来计算的(向上取整),然后再从中截取前 padding_size 个元素作为填充样本。这样做可以确保总样本数量能够被副本数量整除。

如果设置了 drop_last,则需要删除多余的样本使总样本数量能够被副本数量整除,即将索引列表保留前 total_size 个元素。

最后,根据当前副本编号 rank、总样本数量 total_size 和副本数量 num_replicas 对索引列表进行划分和子采样。使用切片操作 indices[self.rank:self.total_size:self.num_replicas],从索引列表中选择属于当前副本的样本索引。确保最终子采样后的样本数量等于 num_samples

最后,返回一个迭代器对象,该迭代器对应着采样后的样本索引列表,并可以通过 for 循环迭代获取每个样本的索引。

测试

虚拟一个数据集

dataset = list([6, 36, 9, 26, 17])

手动创建两个分布式数据采集器

sampler0 = DistributedSampler(dataset,num_replicas=2,rank=0)

sampler1 = DistributedSampler(dataset,num_replicas=2,rank=1)

打印出来

for index in sampler0:

    print("index: {}, data: {}".format(str(index), str(dataset[index])))

结果

sampler0-----index: 4, data: 17 index: 1, data: 36 index: 2, data: 9

sampler1-----index: 0, data: 6 index: 3, data: 26 index: 4, data: 17

拆解开看看

g = torch.Generator()

g.manual_seed(0)

lenth = len(dataset)#5

indices = torch.randperm(lenth, generator=g).tolist()#[4, 0, 1, 3, 2]

结果:

sampler1.total_size#6

sampler1.num_replicas#2

#indices = indices[self.rank:self.total_size:self.num_replicas]

indices = indices[0:6:2]或者indices = indices[1:6:2] 不同sampler的rank不一样

indices#[4, 1, 2]

和前面一致

# SubsetRandomSampler

和这个很像

class SubsetRandomSampler(Sampler[int]):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
    """
    indices: Sequence[int]

    def __init__(self, indices: Sequence[int], generator=None) -> None:
        self.indices = indices
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        for i in torch.randperm(len(self.indices), generator=self.generator):
            yield self.indices[i]

    def __len__(self) -> int:
        return len(self.indices)

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值