Pytorch(8)-保存并加载模型
保存并加载模型
在本节中,我们将研究如何通过保存,加载和运行模型预测来保持模型状态
import torch
import torch.onnx as onnx
import torchvision.models as models
保存和加载模型权重
PyTorch模型将学习到的参数存储在称为的内部状态字典中state_dict。这些可以通过以下torch.save 方法持久化:
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
要加载模型权重,您需要首先创建相同模型的实例,然后使用load_state_dict()方法加载参数。
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default wei