PyTorch分布式数据加载学习 DistributedSampler

[源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler - 罗西的思考 - 博客园

初始化

class DistributedSampler(Sampler[T_co]):
    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # 向上取整
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

如果不drop_last,那就

self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)

向上取整。

迭代

在迭代的时候,如果不能整除,那就把indices的前几个样本复制一遍:

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        if not self.drop_last:  # 一般进入这里,不会丢掉剩下的训练数据
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]   # 把indices的前几个复制一次
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

最关键的还是这行:

indices = indices[self.rank:self.total_size:self.num_replicas]

规定了每个rank的取数据的索引,起始索引是rank,每间隔num_replicas取一个

shuffle数据集

每次epoch都会shuffle数据集,但是不同进程如何保持shuffle之后数据集一致性

DistributedSampler 使用当前的epoch作为随机数种子,在计算index之前就进行配置,从而保证不同进程都使用同样的随机数种子,这样shuffle出来的数据就能确保一致。

sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None), ...,
                            sampler=sampler)
	for epoch in range(start_epoch, n_epochs):
    	if is_distributed:
        	sampler.set_epoch(epoch) # 这设置epoch
        train(loader)

设置 random 种子的具体使用是在迭代函数之中:

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch) # 这里设置随机种子
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

在 PyTorch 中,torch.randperm(n) 函数用于生成一个从 0n-1 的随机排列的整数序列。这个函数是非常有用的,尤其是在需要随机打乱数据或索引时,比如在训练机器学习模型时打乱数据顺序,以确保模型训练的泛化能力。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值