1.torch.load()未指定map_location:
# 修改前, 默认使用GPU:0载入
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["state_dict"])
#修改后, 使用cpu载入
checkpoint = torch.load("checkpoint.pth", map_location='cpu')
model.load_state_dict(checkpoint["state_dict"])
2.将下面代码放在训练代码前,确保在进行分布式训练时,每个进程在初始化时都设置了正确的CUDA设备,并清除了CUDA缓存。
# Set the CUDA device based on local_rank
rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(rank)
torch.cuda.empty_cache()