试验2: DistributedDataParallel是如何做同步的?
- 在开始试验之前我们先说明
DataParallel
,当我们使用DataParallel
去做分布式训练时,假设我们使用四块显卡去做训练,数据的batch_size
设置为8
,则程序启动时只启动一个进程,每块卡会分配batch_size=2
的资源进行forward
操作,当4快卡的forward
操作做完之后,主进程会收集所有显卡的结果进行loss
运算和梯度回传以及参数更新,这些都在主进程中完成,也就是说主进程看到看到的forward
运算的结果是batch_size=8
的。
- 当我们用
DistributedDataParallel
去做分布式训练时,假设我们使用4块显卡训练,总的batch_size
设置为8
,程序启动时会同时启动四个进程,每个进程会负责一块显卡,每块显卡对batch_size=2
的数据进行forward
操作,在每个进程中,程序都会进行的loss的运算、梯度回传以及参数更新,与DataParallel
的区别是,每个进程都会进行loss的计算、梯度回传以及参数更新。这里抛出我们的问题,既然每个进程都会进行loss计算与梯度回传是怎么保证模型训练的同步的呢?
OK, 下面开始我们的试验,不看代码就靠猜。。。
试验用到的代码
- 数据类
datasets.py
: 这个数据类随机生成224x224
大小的图像和其对应的随机标签0-8
class RandomClsDS(Dataset):
def __init__(self):
pass
def __len__(self):
return 10000
def __getitem__(self, item):
image = torch.randn(3,224, 224)
label = np.random.randint(0,9)
return image, label
- 训练类
train.py
import os
import