一、DataParalle缺点
Pytorch单机多卡(GPU)运行的基本方法是使用torch.DataParlle()函数,具体操作参考:
其主要原理:假设有四个GPU,batch_size=64,input_dim为输入数据特征维度。nn.DataParallel() 将随机样本集(64, input_dim)分成四份输入到每个GPU。每个GPU处理(16, input_dim)的数据集(前向传播计算)。然后第一个GPU合并四份输出数据,并计算Loss(反向传播计算)。因此第一个GPU计算量大,负载不均衡。
官方也推荐DistributedDataParallel方法。
二、DistributedDataParallel 原理
待续ing
参考:
1、Distributed communication package - torch.distributed — PyTorch 1.11.0 documentation
5、