PyTorch模型的保存与加载

PyTorch模型的保存和加载

1:pytorch保存和加载模型的方法

1.1 仅保存和加载模型参数

保存模型参数
这种方法只是保存模型的参数,因此加载模型时,应提前将模型准备好,然后载入模型参数

import torch
import torch.nn as nn
 
model = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
 
# 保存整个模型的参数
torch.save(model.state_dict(), 'sample_model.pt')

加载模型

import torch
import torch.nn as nn
 
# 下载模型参数 并放到模型中
loaded_model = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
loaded_model.load_state_dict(torch.load('sample_model.pt'))
print(loaded_model)

显示如下

Sequential(
  (0): Linear(in_features=128, out_features=16, bias=True)
  (1): ReLU()
  (2): Linear(in_features=16, out_features=1, bias=True)
)

1.2 保存和加载整个模型

保存整个模型
这个方法是保存整个模型和参数

import torch
import torch.nn as nn
 
net = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
 
# 保存整个模型,包含模型结构和参数
torch.save(net, 'sample_model.pt')

加载整个模型

import torch
import torch.nn as nn
 
# 加载整个模型,包含模型结构和参数
loaded_model = torch.load('sample_model.pt')
print(loaded_model)

显示如下

Sequential(
  (0): Linear(in_features=128, out_features=16, bias=True)
  (1): ReLU()
  (2): Linear(in_features=16, out_features=1, bias=True)
)

1.3 导出和加载ONNX格式模型

保存模型

import torch
import torch.nn as nn
 
model = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
 
input_sample = torch.randn(16, 128)  # 提供一个输入样本作为示例
torch.onnx.export(model, input_sample, 'sample_model.onnx')

加载模型

import torch
import torch.nn as nn
import onnx
import onnxruntime
 
loaded_model = onnx.load('sample_model.onnx')
session = onnxruntime.InferenceSession('sample_model.onnx')
print(session)

2 模型加载和保存使用到的函数

2.1 保存模型的函数 torch.save

将对象序列化保存到磁盘中,该方法原理是基于python中的pickle来序列化,各种Models,tensors,dictionaries 都可以使用该方法保存。保存的模型文件名可以是.pth, .pt, .pkl。

def save(
    obj: object,
    f: FILE_LIKE,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True
) -> None:

obj:保存的对象
f:一个类似文件的对象(必须实现写入和刷新)或字符串或操作系统。包含文件名的类似路径对象
pickle_module:用于挑选元数据和对象的模块
pickle_protocol:可以指定以覆盖默认协议

备注:关于模型的后缀.pt、.pth、.pkl它们并不存在格式上的区别,只是后缀名不同而已。 torch.save()语句保存出来的模型文件没有什么不同。

2.2 加载模型的函数torch.load

def load(
    f: FILE_LIKE,
    map_location: MAP_LOCATION = None,
    pickle_module: Any = None,
    *,
    weights_only: bool = False,
    **pickle_load_args: Any
) -> Any:

f:类文件对象 (返回文件描述符)或一个保存文件名的字符串
map_location:一个函数或字典规定如何映射存储设备,torch.device对象
pickle_module:用于 unpickling 元数据和对象的模块 (必须匹配序列化文件时的 pickle_module )

2.3 加载模型参数torch.nn.Module.load_state_dict

序列化 (Serialization)是将对象的状态信息转换为可以存储或传输的形式的过程。 在序列化期间,对象将其当前状态写入到临时或持久性存储区。以后,可以通过从存储区中读取或反序列化对象的状态,重新创建该对象。

def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',
                        strict: bool = True):

state_dict:保存 parameters 和 persistent buffers 的字典
strict:可选,bool型。state_dict 中的 key 是否和 model.state_dict() 返回的 key 一致。

2.4 状态字典 state_dict

函数作用是“获取优化器当前状态信息字典”,在神经网络中模型上训练出来的模型参数,也就是权重和偏置值。在Pytorch中,定义网络模型是通过继承torch.nn.Module来实现的。其网络模型中包含可学习的参数(weights, bias, 和一些登记的缓存如batchnorm’s running_mean 等)。模型内部的可学习参数可通过两种方式进行调用:
通过model.parameters()这个生成器来访问所有参数。
通过model.state_dict()来为每一层和它的参数建立一个映射关系并存储在字典中,其键值由每个网络层和其对应的参数张量构成

def state_dict(self, destination=None, prefix='', keep_vars=False):

除模型外,优化器对象(torch.optim)同样也有一个状态字典,包含的优化器状态信息以及使用的超参数。由于状态字典属于Python 字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都比较便捷。

3 指定map_location 加载参数

采用仅加载模型参数的方式,指定设备类型进行模型加载,代码如下:

model_path = '/opt/sample_model.pth'
 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
map_location = torch.device(device)
 
model.load_state_dict(torch.load(self.model_path, map_location=self.map_location))
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值