遇到的问题
在使用DDP进行多卡训练时,通过 nvidia-smi
,查看显存占用,发现GPU0占用相较于其他GPU更高,并且每一个进程都在GPU上有占用,当显存较为紧张时,可能导致爆显存。
解决方法
解决方法1
网上找到的大部分结果都只提到了第一种解决办法。
pytorch在load模型时,通过查看pytorch官方文档可以发现,torch.load()
方法,在未指定map_location时,torch.load() uses Python’s unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g. because the run time system doesn’t have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the argument.map_location
,可以看出,载入模型时,先被载入到cpu然后移动到对应设备,这里对应设备指的是GPU0,而通过指定参数map_location,可以使用对应设备进行载入。因此,可以修改对应部分代码如下:
# 修改前
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["state_dict"])
#修改后
checkpoint = torch.load("checkpoint.pth", map_location='cpu') # 使用cpu载入
model.load_state_dict(checkpoint["state_dict"])
解决方法2
然而,通过解决方法1,并未能解决我的问题。通过google,最终找到了解决该问题的方式。可以参考https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113。通过在模型初始化前添加以下两句,从而解决该问题
torch.cuda.set_device(rank)
torch.cuda.empty_cache()