pytorch分布式训练指北
第一章节的部分是简单的科普,想看如何在本地及docker内跑pytorch分布式的直接看第二章。
1、pytorch分布式代码基础
1.1、如何写pytorch的分布式代码
这个部分大概讲一下如何写分布式的Pytorch代码,首先,官方pytorch(v1.0.10)在分布式上给出的api有这么两个非常重要的
,需要使用的:
torch.nn.parallel.DistributedDataParallel
这个api和DataParallel相类似,也是一个模型wrapper。这个api可以帮助我们在不同机器的多个模型拷贝之间平均梯度。
torch.utils.data.distributed.DistributedSampler
在多机多卡情况下分布式训练数据的读取也是一个问题,不同的卡读取到的数据应该是不同的。dataparallel的做法是直接
将batch切分到不同的卡,这种方法对于多机来说不可取,因为多机之间直接进行数据传输会严重影响效率。于是有了利用sampler
确保dataloader只会load到整个数据集的一个特定子集的做法。DistributedSampler就是做这件事的。它为每一个子进程划分
出一部分数据集,以避免不同进程之间数据重复。
到这里要是还没看明白,那就建议看看谷歌。下面给出Pytorch的代码如何改成分布式代码:
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
dataset = your_dataset()
datasampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
dataloader = DataLoader(dataset, batch_size=batch_size_per_gpu, sampler=datasampler)
model = your_model()
##这个部分东西比较多,文章的后面我稍微做一些补充,细节可以去谷歌
model = DistributedDataPrallel(model, device_ids=[local_rank], output_device=local_rank)
1.2、各种参数介绍
想要使用DistributedD