保存整个模型,包括模型结构+参数
torch.save(self.model.cpu(), output_path)
相应的加载是
model = torch.load(output_path)
缺点:这种方式保存的模型只能在相同的环境中使用,因为它依赖于模型定义的代码。如果您在不同的环境中使用这种方式保存的模型,可能会出现错误。(相同的环境指的是运行代码的环境,包括操作系统、Python 版本、PyTorch 版本以及其他依赖库的版本都相同。)
保存模型的参数
torch.save(self.model.state_dict(), output_path)
相应的加载是
state_dict = torch.load(save_path)
bert_model = Architecture() # 实例化模型类
bert_model.load_state_dict(state_dict)
优点:想要在不同的环境中使用保存的模型,建议使用 state_dict
来保存和加载模型的参数。这样,您只需要在新环境中定义相同的模型结构,然后使用 load_state_dict
函数加载保存的参数即可。