Pytorch 并行训练(DP, DDP)的原理和应用
1. 前言
并行训练可以分为数据并行和模型并行。
-
模型并行
模型并行主要应用于模型相比显存来说更大,一块 device 无法加载的场景,通过把模型切割为几个部分,分别加载到不同的 device 上。比如早期的 AlexNet,当时限于显卡,模型就是分别加载在两块显卡上的。
-
数据并行
这个是日常会应用的比较多的情况。每一个 device 上会加载一份模型,然后把数据分发到每个 device 并行进行计算,加快训练速度。
如果要再细分,又可以分为单机多卡,多机多卡。这里主要讨论数据并行的单机多卡的情况。
2. Pytorch 并行训练
常用的 API 有两个:
- torch.nn.DataParallel(DP)
- torch.nn.DistributedDataParallel(DDP)
DP 相比 DDP 使用起来更友好(代码少),但是 DDP 支持多机多卡,训练速度更快,而且负载相对要均衡一些。所以优先选用 DDP 吧。
2.1 训练模型的过程
在开始怎么调用并行的接口之前,了解并行的过程是有必要的。首先来看下模型训练的过程。
2.2 DP
2.2.1 DP 的计算过程
DP 并行的具体过程可以参考下图两幅图。
上图清晰的表达了 torch.nn.DataParallel 的计算过程。
- 将 inputs 从主 GPU 分发到所有 GPU 上
- 将 model 从主 GPU 分发到所有 GPU 上
- 每个 GPU 分别独立进行前向传播,得到 outputs
- 将每个 GPU 的 outputs 发回主 GPU
- 在主 GPU 上,通过 loss function 计算出 loss,对 loss function 求导,求出损失梯度
- 计算得到的梯度分发到所有 GPU 上
- 反向传播计算参数梯度