pytroch加载gpu训练的DataParallel模型
加载.pkl模型到cpu上
model = torch.load('xxx.pkl')
model.cpu()
报错,cpu环境下不能直接导入gpu训练的DataParallel模型,找了一圈,有建议遍历网络把module去掉的。找到一个简单方法,在gpu上把模型转化掉
model = torch.load('xxx.pkl')
real_model = model.module
torch.save(real_model, 'xxxx.pkl')
new_model = torch.load('xxxx.pkl')
new_model.cpu()
解决。
参考博客:
1、https://blog.csdn.net/qq_41895190/article/details/103350508