pytorch DistributedDataParallel

 

1. 为什么每个epoch都要调用distributed_sampler.set_epoch(epoch) 函数一次?

torch.utils.data.distributed.DistributedSampler: 在多机多卡情况下分布式训练数据的读取也是一个问题,不同的卡读取到的数据应该是不同的。dataparallel的做法是直接将batch切分到不同的卡,这种方法对于多机来说不可取,因为多机之间直接进行数据传输会严重影响效率。于是有了利用distributed_sampler确保dataloader只会load到整个数据集的一个特定子集的做法。DistributedSampler为每一个子进程划分出一部分数据集,以避免不同进程之间数据重复。

# 分布式训练示例
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel

dataset = your_dataset()
datasampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
dataloader = DataLoader(dataset, batch_size=batch_size_per_gpu, sampler=datasampler)
model = your_model()
model = DistributedDataPrallel(model, device_ids=[local_rank], output_device=local_rank)

但是DistributedSampler 会把数据划分成 num_gpu 份,不同的 GPU 拿自己那一份,那么训练过程中每个 GPU 只能见到自己那一份数据,不能看到别的 GPU 的数据。而且尽管 DistributedSampler 默认是 shuffle=True,但是每个epoch和每次运行(重新运行整个程序),epoch之间每个 GPU 的输出顺序是一样的,没有 shuffle 成功。问题就在于DistributedSampler对子集的划分只有一次,而且依赖于 seed ,参考源码 DistributedSampler , 可见g.manual_seed(self.epoch) 中的seed就是self.epoch, 默认self.epoch=0,如果self.epoch 不更新,产生的序列就是固定的。

 def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))


        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
    def __init__(dataset, num_replicas, rank=0, shuffle=True)
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

所以 DistributedSampler 也提供了一个 set 函数来改变 self.epoch

def set_epoch(self, epoch):
    self.epoch = epoch

所以在运行的时候要不断调用这个 set_epoch 函数,

distributed_sampler.set_epoch(epoch)

 

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值