🤵 Author :Horizon John
✨ 编程技巧篇:各种操作小结
🎇 机器视觉篇:会变魔术 OpenCV
💥 深度学习篇:简单入门 PyTorch
🏆 神经网络篇:经典网络模型
💻 算法篇:再忙也别忘了 LeetCode
模型保存
torch.save()
1)保存全部
保存整个模型
torch.save(model, path)
2)保存部分
只保存模型训练的参数(不包括网络结构)
torch.save(model.state_dict(), path)
模型加载
torch.load()
1)加载全部
加载整个模型
model = torch.load(path)
2)加载部分
只加载模型训练的参数(不包括网络结构)
model = Net() # 网络
model = model.to(device)
checkpoint = torch.load(path)
model.load_state_dict(checkpoint)
加载模型GPU和CPU转换
使用 map_location
参数对 GPU
和 CPU
进行转化
1)GPU 转 CPU
model = torch.load(PATH, map_location='cpu')
2)CPU 转 GPU
model = torch.load(PATH, map_location=lambda storage, loc: storage.cuda(0))
3)GPU之间转换
torch.load(PATH, map_location={'cuda:0':'cuda:1'}) # GPU0转到GPU1