模型训练后,需要保存到文件,以供测试和部署;或,继续之前的训练状态.
1. Best Practices
主要有两种模型序列化保存和加载恢复的方法.
1.1 方法 M1 - 推荐
只保存和加载恢复模型参数(model parameters):import torch
# 保存
torch.save(the_model.state_dict(), PATH)
# 恢复
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
# 该方法需要自己另导入模型的网络结构信息.
1.2 方法 M2
同时保存模型的参数和网络结构信息:import torch
# 保存
torch.save(the_model, PATH)
# 恢复
the_model = torch.load(PATH)
# 该方法保存的数据绑定着特定的 classes 和所用的确切目录结构. ‘
# 因此,再加载后经过许多重构后,可能会被打乱.
2. Stackoverflow 回答
根据应用场景,选择模型保存和加载恢复方法.
场景 C1 - 模型保存自用于推断
自己保存模型,自己恢复模型,然后,修改模型为 evaluation 模式.
这是因为,默认情况时,网络模型训练时往往有 BatchNorm 和 Dropout 网络层.# 模型保存
torch.save(model.state_dic