saveloadrun_tutorial 介绍了在Pytorch中如何保存和加载训练后的模型(checkpoint),以及如何运行(或者说测试)test一个已经训练好的模型。
教程主要包括以下内容:
-
保存模型(checkpoint):介绍如何在Pytorch中保存训练后得到的模型参数,以及如何将这些参数保存到文件中。
-
加载模型(checkpoint):你会学到如何从保存的文件中加载模型参数,以便于从上一次训练停止的地方继续进行训练,或者进行模型测试。
-
测试模型:这部分介绍了如何使用加载的训练好的模型来进行模型测试。同时,还介绍了如何在测试过程中打印模型的输出结果,并对结果进行解释。
其他部分包括了一些常用的Pytorch函数和技巧,如如何在模型训练时进行batch数据加载、使用GPU加速模型训练以及在Pytorch中使用tensorboard。
模型的保存和加载
在本节中,我们将介绍如何使用保存、加载和执行模型预测来持久化模型状态。
import torch
import torchvision.models as models # 导入 torchvision 库中的预训练模型
保存和加载模型权重
PyTorch 模型使用内部状态字典 state_dict
存储学习到的参数。这些参数可以通过 torch.save
方法进行持久化保存:
model = models.vgg16(weights='IMAGENET1K_V1') # 加载预训练模型 VGG16
torch.save(model.state_dict(), 'model_weights.pth') # 保存模型参数到文件 'model_weights.pth'
要加载模型的权重,需要先创建一个与原模型相同的实例,然后使用 load_state_dict()
方法加载参数。
model = models.vgg16() # 创建一个未训练的 VGG16 模型
model.load_state_dict(torch.load('model_weights.pth')) # 加载存储的模型参数
model.eval() # 将模型设置为评估模式
注意:
- 在推理将丢弃层和批处理规范化层设置为评估模式之前,请务必调用
model.eval()
方法。如果不这样做,将产生不一致的推理结果。
保存和加载带有形状的模型
当我们加载模型权重时,我们需要先实例化模型类,因为该类定义了网络的结构。如果我们想要与模型一起保存该类的结构,我们可以将模型 model
本身(而不是其状态字典 model.state_dict()
)传递给保存函数。
# 保存完整的模型(包含模型结构和权重信息)到文件'model.pth'
torch.save(model, 'model.pth')
然后我们可以像这样加载模型:
model = torch.load('model.pth')
注意:
- 这种方法使用 Python pickle 模块,因此它依赖于加载模型时可用的实际类定义。