1. 保存整个模型(包括结构和参数)
- 保存内容:整个模型,包括模型结构和模型参数。
- 文件大小:较大,因为保存了模型结构和权重。
- 优点:加载时非常方便,不需要重新定义模型结构。
- 缺点:如果模型的代码结构发生了变化(如版本升级或修改了代码),可能会导致加载失败。
保存方法:
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=True)
torch.save(vgg16, '/path/to/directory/vgg16_origin.pth')
加载方法
vgg16_loaded = torch.load('/path/to/directory/vgg16_origin.pth')
2.只保存模型参数(推荐方式)
- 保存内容:仅保存模型的参数(权重),不保存模型结构。
- 文件大小:较小,因为只保存了模型参数。
- 优点:与保存整个模型相比,文件更小,兼容性更好,适用于部署和分享。
- 缺点:加载时需要手动重新定义模型结构。
保存方法:
torch.save(vgg16.state_dict(), '/path/to/directory/vgg16_state_dict.pth')
读取方法
vgg16 = torchvision.models.vgg16() # 重新定义模型结构
vgg16.load_state_dict(torch.load('/path/to/directory/vgg16_origin.pth'))
3.保存和加载检查点(Checkpoint)
- 保存内容:除了模型参数外,还可以保存训练状态(如当前的 epoch、损失函数、优化器状态等)。
- 优点:适用于长时间训练任务,可以从检查点继续训练。
- 缺点:文件大小可能较大,包含了更多信息。
保存方法:
checkpoint = {
'epoch': epoch,
'model_state_dict': vgg16.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}
torch.save(checkpoint, 'checkpoint.pth')
读取方法:
checkpoint = torch.load('checkpoint.pth')
vgg16.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
4.Torch.jit
torch.jit 是 PyTorch 提供的工具,用于将模型转换为 TorchScript 格式,适合用于部署。
- 保存内容:将模型保存为 TorchScript 格式,包含模型结构和参数。TorchScript 模型可以脱离 Python 运行环境,部署到其他支持 LibTorch(C++ 接口)的平台上,如移动设备、嵌入式设备、Web 服务等。
- 优点:适用于跨平台部署,如移动设备。
- 缺点:使用和调试略有复杂性。
保存方法:
scripted_model = torch.jit.script(vgg16)
scripted_model.save('vgg16_scripted.pth')
读取方法:
loaded_scripted_model = torch.jit.load('vgg16_scripted.pth')