本文来源于:Pytorch Distributed - 知乎
Pytorch Distributed
现在的模型越来越大,并行显得越来越重要,而众所周知,pytorch 的并行文档写的非常不清楚,不仅影响使用,甚至我们都不知道他的工作原理。一次偶然的机会,我发现了几篇在这方面写的很好的文章,因此也准备参考别人的(参考的文章在Reference部分列出)再结合自己的使用经验总结一下。
nn.DataParallel
Pytorch的数据并行方式,是经常使用的单机多卡的并行方式。
工作原理
这种方式只有一个进程,假设我们有一个8卡的机器,然后设置的batchsize是128,那么单卡的batchsize就是32,在训练的时候,他把我们的模型分发到每个GPU上,然后在每个GPU上面对着32张图进行前向和反向传播得到loss,之后将所有的loss汇总到0卡上,对0卡的参数进行更新,然后,非常重要的一点:他再把更新后的模型传给其他7张卡,这个操作是在每个batch跑完都会执行一次的。
可以想象一下,每个batch,GPU之间都要互相传模型,GPU的通信显然成为了一个极大的瓶颈,直接导致GPU在一段时间是闲着的,也就导致了GPU的利用率低的问题。
缺点
- GPU利用率低
- 不支持多机多卡
- 不支持混合精度
使用方法
nn.DataParallel的使