参考博客:添加链接描述
- 错误代码
----之前训练了一个网络模型,保存了网络训练参数,代码如下
torch.save(module1.state_dict(),"module1_{}.pth".format(i))
- 错误原因:
----在进行模型测试的时候,使用如下代码对训练模型直接进行加载:
model=torch.load("Pytorch/module1_9.pth")
- 改正方法
----先载入网络结构,再导入网络的参数( Module1是我的网络模型)
model = Module1() # 导入网络结构
model.load_state_dict(torch.load("Pytorch/module1_9.pth")) # 导入网络的参数
- 另外
----保存模型的两种方式:保存整个训练模型、只保存模型参数
# 保存整个模型
torch.save(module,"module_{}.pth".format(i))
# 只保存模型参数(官方推荐,并且占内存小)
torch.save(module.state_dict(),"module_{}.pth".format(i))