使用DistributedDataParallel(DDP)时遇到额外进程导致GPU0显存不均衡的问题

遇到的问题

在使用DDP进行多卡训练时,通过 nvidia-smi,查看显存占用,发现GPU0占用相较于其他GPU更高,并且每一个进程都在GPU上有占用,当显存较为紧张时,可能导致爆显存。
图片来源知乎

解决方法

解决方法1

网上找到的大部分结果都只提到了第一种解决办法。
pytorch在load模型时,通过查看pytorch官方文档可以发现,torch.load()方法,在未指定map_location时,torch.load() uses Python’s unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g. because the run time system doesn’t have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the argument.map_location,可以看出,载入模型时,先被载入到cpu然后移动到对应设备,这里对应设备指的是GPU0,而通过指定参数map_location,可以使用对应设备进行载入。因此,可以修改对应部分代码如下:

# 修改前
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["state_dict"])

#修改后
checkpoint = torch.load("checkpoint.pth", map_location='cpu') # 使用cpu载入
model.load_state_dict(checkpoint["state_dict"])

解决方法2

然而,通过解决方法1,并未能解决我的问题。通过google,最终找到了解决该问题的方式。可以参考https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113。通过在模型初始化前添加以下两句,从而解决该问题

torch.cuda.set_device(rank)
torch.cuda.empty_cache()
  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值