深度学习模型保存模型参数的方法有两种:
1.保存整个网络(模型结构+模型参数):
# 保存整个模型和参数
torch.save(model_object, 'convit_tiny.pth')
# 对应的加载模型代码为
model = torch.load('convit_tiny.pth')
print(model)
此时print的是整个网络的模型结构;
若要加载模型的参数:
model = torch.load('convit_tiny.pth')
args = model.state_dict()
print(args)
此时输出的是模型的训练参数:
2.直接保存网络的模型参数:
# 将my_resnet模型储存为my_resnet.pth,此时保存的仅仅是模型的参数
torch.save(model.state_dict(), "convit.pth")
# 直接加载参数
args = torch.load("convit.pth")
# 若要加载模型则先需要初始化之前所定义的网络
new_model = Net()
# 再使用load_state_dict方法将权重加载进网络
# 注意:model.state_dict()其实返回的是一个OrderDict,存储了网络结构的名字和对应的参数;而这里是导入参数因此用的是model.load_state_dict()而不是model.state_dict()
new_model.load_state_dict(torch.load('convit.pth'))