当提到保存和加载模型时,有三个核心功能需要熟悉:
- torch.save:将序列化的对象保存到disk。这个函数使用Python的pickle实用程序进行序列化。使用这个函数可以保存各种对象的模型、张量和字典。
- torch.load:使用pickle unpickle工具将pickle的对象文件反序列化为内存。
- torch.nn.Module.load_state_dict:使用反序列化状态字典加载model’s参数字典。
一、模型保存与调用方式一:只保存模型参数
1、模型保存
model = TheModelClass(*args, **kwargs)
# ------------- 模型训练: 开始 -------------
......
# ------------- 模型训练: 结束 -------------
PATH = r'.\saved_model\model_state_dict_step_01.pt'
torch.save(model.state_dict(), PATH)
在保存模型进行推理时,只需要保存训练过的模型的学习参数即可。
一个常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型。
2、模型加载
# 重构模型结构(与保存的模型结构要完全一致)
model = TheModelClass(*args, **kwargs)
# 根据模型结构,调用存储的模型参数
model.load_state_dict(torch.load(PATH))
注意,load_state_dict()
函数接受一个dictionary对象,而不是保存对象的路径。这意味着您必须在将保存的state_dict传至load_state_dict()函数之前反序列化它(用torch.load(PATH)
进行反序列化)。
加载后的模型如果用于验证或测试,则先调用model.eval()
,以便在运行推断之前将dropout和batch规范化层设置为评估模式。如果不这样做,将会产生不一致的推断结果。
二、模型保存与调用方式二:保存整个模型(结构+参数)
1、模型保存
model = TheModelClass(*args,<