这里介绍只保存加载模型的参数的方法,因为速度快,占内存小。
保存语法:
torch.save(model.state_dict(),'model.pth')
加载模型的语法:
因为只保存了模型的参数,所以需要先定义一个网络对象,然后再加载模型的参数。
例如:
model = ClassNet()
将模型参数加载到新模型中,torch.load返回的时一个OrderedDict,model.state_dict()把模型的所有参数都以OrderdeDict的形式保存了下来。
{'epoch': 1, 'model_name': 'resnet', 'state_dict': OrderedDict([('backbone.conv1.weight', tensor([[[[-9.6635e-03, -5.8054e-03, -1.7499e-03, ..., 5.6849e-02,
1.7084e-02, -1.2774e-02],
[ 1.1954e-02, 1.0023e-02, -1.0967e-01, ..., -2.7083e-01,
-1.2892e-01, 3.8908e-03],
[-6.0705e-03, 5.9568e-02, 2.9577e-01, ..., 5.2005e-01,
2.5649e-01, 6.3826e-02],
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
这里附上一份自己的torch分类代码中的模型的保存和加