load_net[“params”] 报keyerror
加载模型后查看对应参数是什么
model2 = torch.load(m1_path + "xxx.pth")
print(model1.keys())
- 若输出如下:
已经有相应参数不需要执行 load_net[“params”] - 若输出如下
则需要load_net[“params”]
pytorch 加载模型的两种方式
- 直接加载模型和参数
# 保存和加载整个模型 torch.save(model_object, 'resnet.pth') model = torch.load('resnet.pth')
- 分别加载网络的结构和参数
# 将my_resnet模型储存为my_resnet.pth torch.save(my_resnet.state_dict(), "my_resnet.pth") # 加载resnet,模型存放在my_resnet.pth my_resnet.load_state_dict(torch.load("my_resnet.pth"))