Pytorch保存模型

1. 保存整个模型(包括结构和参数)

  • 保存内容:整个模型,包括模型结构和模型参数。
  • 文件大小:较大,因为保存了模型结构和权重。
  • 优点:加载时非常方便,不需要重新定义模型结构。
  • 缺点:如果模型的代码结构发生了变化(如版本升级或修改了代码),可能会导致加载失败。

 保存方法:

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=True)
torch.save(vgg16, '/path/to/directory/vgg16_origin.pth')

加载方法 

vgg16_loaded = torch.load('/path/to/directory/vgg16_origin.pth')

2.只保存模型参数(推荐方式) 

  • 保存内容:仅保存模型的参数(权重),不保存模型结构。
  • 文件大小:较小,因为只保存了模型参数。
  • 优点:与保存整个模型相比,文件更小,兼容性更好,适用于部署和分享。
  • 缺点:加载时需要手动重新定义模型结构。

保存方法:

torch.save(vgg16.state_dict(), '/path/to/directory/vgg16_state_dict.pth')

读取方法

vgg16 = torchvision.models.vgg16()  # 重新定义模型结构
vgg16.load_state_dict(torch.load('/path/to/directory/vgg16_origin.pth'))

3.保存和加载检查点(Checkpoint) 

  • 保存内容:除了模型参数外,还可以保存训练状态(如当前的 epoch、损失函数、优化器状态等)。
  • 优点:适用于长时间训练任务,可以从检查点继续训练。
  • 缺点:文件大小可能较大,包含了更多信息。

保存方法:

checkpoint = {
    'epoch': epoch,
    'model_state_dict': vgg16.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}

torch.save(checkpoint, 'checkpoint.pth')

读取方法:

checkpoint = torch.load('checkpoint.pth')
vgg16.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

4.Torch.jit

        torch.jit 是 PyTorch 提供的工具,用于将模型转换为 TorchScript 格式,适合用于部署。 

  • 保存内容:将模型保存为 TorchScript 格式,包含模型结构和参数。TorchScript 模型可以脱离 Python 运行环境,部署到其他支持 LibTorch(C++ 接口)的平台上,如移动设备、嵌入式设备、Web 服务等。
  • 优点:适用于跨平台部署,如移动设备。
  • 缺点:使用和调试略有复杂性。

    保存方法:

scripted_model = torch.jit.script(vgg16)
scripted_model.save('vgg16_scripted.pth')

    读取方法:

 

loaded_scripted_model = torch.jit.load('vgg16_scripted.pth')

 

         

 

 

 

 

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值