- 使用pytorch加载模型的时候,通常语句写法
model = torch.load(model_file)
- 只有CPU时,修改为:
model = torch.load(model_path, map_location='cpu')
- 只有一个GPU时,修改为:
model = torch.load(model_path, map_location='cuda:0')
- 多个GPU时,比如2个GPU, 修改为:
多个GPU,以此类推model = torch.load(model_path, map_location={'cuda:1': 'cuda:0'})