多GPU并行
torch.nn.DataParallel
使用非常简单,基本只需添加一行代码就可扩展到多GPU。
如果想限制GPU使用,可以设置os.environ['CUDA_VISIBLE_DEVICES'] = "0, 2, 4"
,注意程序执行时会对显卡进行重新编号,不一定跟实际完全对应。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
print("Let's use ", torch.cuda.device_count(), "GPUs.")
model = nn.DataParallel(model)
model.to(device) # 等价于 model = model.to(device)
data = data.to(device) # 注意:数据必须要有赋值,不能写成 data.to(device)
显存不平衡问题:模型是保存在第0卡上的,计算loss的梯度也默认在0卡上。所以第0卡会比其他的卡占用更多显存,尤其是当模型比较大的时候。这样会限制整体batch_size的大小,同时其余卡的显存也没有完全利用。解决方案是BalancedDataParallel
和DistributedDataParallel
。
原理:首先把模型放在第0块卡上,然后通过nn.DataParallel
找到所有可用的显卡并将模型进行复制。运行时将每个batch的数据平均分到不同GPU进行forward计算,将loss汇总到第0卡反向传播,最后将更新后的模型参数再复制到其他GPU中。所以要求batch_size >= GPU数量。
BalancedDataParallel
代码:有大佬改进了transformer-XL的源码(GitHub)
做法是自己实现一个继承自DataParallel的BalancedDataParallel
类,手动调整每个batch数据在多GPU的分配,然后就可以指定第0卡少处理一些数据,从而充分利用每块卡的显存。
用法与DataParallel完全