#保存和加载模型参数
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))
相关的方法如下:
torch.nn.Module.state_dict()
torch.nn.Module.load_state_dict()
torch.save()
torch.load()
需要注意的是,文件保存在自定义当前路径下,使用
import os
os.getcwd()
可以查看当前路径。