一、存储/加载模型
可以选择保存整体model网络结构和参数
PATH = 'saved_model.pth'
# 保存整个model:
torch.save(model_0, PATH)
# 加载整个model:
model_1 = torch.load(PATH)
or只保存model参数
PATH = 'saved_model.pth'
# 由于只加载参数,因此需要提前定义网络结构,例如:
class Net(nn.Module):
...
# 只保存参数:
torch.save(model_0.state_dict, PATH)
# 分别加载网络和参数:
model_1 = Net()
model_1.load_state_dict(torch.load(PATH), strict=False) # 只会加载键值相同的参数
二、多GPUs下加载过程出错
使用 torch.nn.DataParallel(model_0, device_ids=[0, 1])
语句后,加载的模型变量会多“module”关键字,例如:
model_0 = torch.load('resnet50.pth')
model_1 = torch.nn.DataParallel(model_0, device_ids=[0, 1])
torch.save(model_1, PATH)
model_2 = torch.load(PATH)
model_3 = model_2.module
调试界面显示:
如果直接加载model_2,则可能会因为多了module关键字而报错:
IndexError: list index out of range
其实就是网络结构和参数对不上了(因为多了module)
可以使用model_3 = model_2.module
语句,去掉module。
注意:
(1) 每次使用 torch.nn.DataParallel(model_0, device_ids=[0, 1])
都会生成一个module关键字,所以用了几次就要去掉几次;
(2) 哪怕 torch.nn.DataParallel(model_0, device_ids=[0,])
也会有module的,如果只想用1个GPU,直接用tensor.cuda()
参考
https://blog.csdn.net/CV_YOU/article/details/86670188
https://blog.csdn.net/qq_37959202/article/details/105104278