本文主要介绍如何加载和保存 PyTorch 的模型的区别与联系。这里主要有两个核心函数:
torch.save
:把序列化的对象保存到硬盘。它利用了 Python 的pickle
来实现序列化。模型、张量以及字典都可以用该函数进行保存;torch.load
:采用pickle
将反序列化的对象从存储中加载进来。
-
保存和加载整个模型:
- 保存模型:使用
torch.save()
函数将整个模型保存到文件中。可以指定文件路径和名称以及要保存的对象。torch.save(model, 'model.pth')
- 加载模型:使用
torch.load()
函数加载保存的模型。可以通过指定文件路径和名称来加载模型。model = torch.load('model.pth')
- 保存模型:使用
-
保存和加载模型参数:
- 保存模型参数:使用
state_dict()
方法获取模型的参数字典,然后使用torch.save()
函数将参数字典保存到文件中。torch.save(model.state_dict(), 'model_weights.pth')
- 加载模型参数:首先需要创建一个与原始模型结构相同的实例,然后使用
load_state_dict()
方法加载保存的参数字典。m
- 保存模型参数:使用