1. 为什么每个epoch都要调用distributed_sampler.set_epoch(epoch) 函数一次?
torch.utils.data.distributed.DistributedSampler: 在多机多卡情况下分布式训练数据的读取也是一个问题,不同的卡读取到的数据应该是不同的。dataparallel的做法是直接将batch切分到不同的卡,这种方法对于多机来说不可取,因为多机之间直接进行数据传输会严重影响效率。于是有了利用distributed_sampler确保dataloader只会load到整个数据集的一个特定子集的做法。DistributedSampler为每一个子进程划分出一部分数据集,以避免不同进程之间数据重复。
# 分布式训练示例
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
dataset = your_dataset()
datasampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
dataloader = DataLoader(dataset, batch_size=batch_size_per_gpu, sampler=datasampler)
model = your_model()
model = DistributedDataPrallel(model, device_ids=[local_rank], output_device=local_rank)
但是DistributedSampler 会把数据划分成 num_gpu 份,不同的 GPU 拿自己那一份,那么训练过程中每个 GPU 只能见到自己那一份数据,不能看到别的 GPU 的数据。而且尽管 DistributedSampler 默认是 shuffle=True,但是每个epoch和每次运行(重新运行整个程序),epoch之间每个 GPU 的输出顺序是一样的,没有 shuffle 成功。问题就在于DistributedSampler对子集的划分只有一次,而且依赖于 seed ,参考源码 DistributedSampler , 可见g.manual_seed(self.epoch) 中的seed就是self.epoch, 默认self.epoch=0,如果self.epoch 不更新,产生的序列就是固定的。
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __init__(dataset, num_replicas, rank=0, shuffle=True)
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
所以 DistributedSampler 也提供了一个 set 函数来改变 self.epoch
def set_epoch(self, epoch):
self.epoch = epoch
所以在运行的时候要不断调用这个 set_epoch 函数,
distributed_sampler.set_epoch(epoch)