最近在学习mmdetection的分布式数据并行,其中涉及到pytorch的DP和DDP,到网上搜集了很多资料,通过这篇文章来总结一下。
一、并行
随着网络模型越来越大,并行技术越来越必不可少。这篇文章中记录了我们为什么需要并行技术,以及对并行技术做了简要的总结:常见的分布式并行策略。
简而言之,并行技术可以加快训练速度以及解决显存不足的问题。
今天我们主要讨论的数据并行。
二、DataParallel
pytorch提供的DataParallel是实现数据并行最简单的方式,只需要一行代码就可以实现数据并行。
model = nn.DataParallel(model)
这篇文章中清晰的展示了DP的运行流程:单机多卡数据并行-DataParallel(DP)
这篇文章中对DP的代码进行了分析:PyTorch 源码解读之 DP & DDP:模型并行和分布式训练解析
简单来说,DP通过多线程的方式实现数据并行。
在将数据输入到模型之前,需要将数据和模型加载到gpu。
device = torch.device('cuda')
model = model.to(device)
data = data.to(device),
label = label.to( device)
当执行下面代码时,进入DP的forward函数
output = model(data)
scatter负责将模型从主gpu分发到其他gpu,每个gpu上都是完整的模型
replicate负责将数据分发的其他gpu,一个batch的数据会平均分发
parallel_apply负责前向传播,这里会通过多线程进行
gather负责收集所有gpu上的输出
作为最简单的数据并行方式,DP的缺点显而易见,这篇文章中记录了DP的缺点:Pytorch的nn.DataParallel
主要的缺点时负载不均衡的问题,通过上面的流程我们知道,主gpu会对输入进行汇总、计算Loss、对梯度进行汇总、更新权重。
汇总其他gpu的数据、计算过程中的累加、中间结果的产生都会造成主gpu占用的显存高于其他gpu。
另外多线程受限于全局解释器锁主要用来提高IO密集型程序的性能,对计算密集型程序并没有多少帮助。
在文章Pytorch的nn.DataParallel中作者将优化器也使用DP进行了包装,由此引发了计算Loss的问题也很有趣。