pytorch-lightning训练的结果保存为xxx.pt
问题描述:
模型是基于pytorch-lightning写的,但是由于pl的forward函数是forward(self, *args, **kwargs),直接用torch.jit.script(model)来保存,会报错。
保存为.pt是为了在C++代码中执行推理,关于这个内容,很多人写过教程了。但是目前关于我遇到的这个问题,还没有人给出系统的回答。
问题代码:
直接加载模型,然后保存:
import torch
# 加载PL NN模型
model = ABC(xxxx)
# 加载权重
model.load_state_dict(ckptfile)
# torch.jit转换模型
f = torch.jit.script(model)
# 保存为xxx.pt文件
f.save(pt_file)
报错如下:
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
File "xxx/site-packages/pytorch_lightning/core/lightning.py", line 458
def forward(self, *args, **kwargs):
~~~~~~~ <--- HERE
r"""
Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define
解决方案
先看看GPT怎么说的,以下内容来自 bing,我的方案在最后。
GPT方案
-
Q:
- 报错,torch.jit在保存模型时,要求forward()函数中不能有变量,但是pytorch lightning的forward()函数采样了*args, **kwargs的输入,与torch.jit的要求不符。
-
A:
-
您好,您提到的问题是一个常见的问题。在使用 PyTorch Lightning 时,
forward
函数通常会采用*args
和**kwargs
作为输入参数,以便更灵活地处理不同类型的输入。然而,当使用torch.jit.script
或torch.jit.trace
将模型转换为 TorchScript 格式时,这些函数可能无法处理这种灵活性。一种解决方案是在将模型转换为 TorchScript 格式之前,重写
forward
函数,使其具有固定数量的输入参数。例如,如果您的模型始终接受两个输入张量x
和y
,则可以将forward
函数重写为:def forward(self, x, y): # 模型前向传播逻辑
这样,在使用
torch.jit.script
或torch.jit.trace
将模型转换为 TorchScript 格式时,就不会出现问题了。
-
-
Q:
- 是否有其他方案?
-
A:
-
除了使用
torch.jit.script
或torch.jit.trace
将 PyTorch Lightning 模型转换为 TorchScript 格式外,您还可以考虑使用 ONNX(开放神经网络交换)格式。ONNX 是一种用于表示深度学习模型的开放标准,它允许您在不同的深度学习框架之间转换模型1。要将 PyTorch Lightning 模型转换为 ONNX 格式,您可以使用 PyTorch 的
torch.onnx.export
函数。下面是一个简单的示例,展示了如何将 PyTorch Lightning 模型转换为 ONNX 格式:import torch from pytorch_lightning import LightningModule class MyModel(LightningModule): def __init__(self): super().__init__() self.layer = torch.nn.Linear(10, 1) def forward(self, x): return self.layer(x) model = MyModel() input_sample = torch.randn(1, 10) torch.onnx.export(model, input_sample, 'model.onnx')
在上面的示例中,我们首先定义了一个简单的 PyTorch Lightning 模型
MyModel
,然后使用torch.onnx.export
函数将其转换为 ONNX 格式并保存到文件model.onnx
中。在 C++ 中,您可以使用 ONNX 运行时库来加载并推理保存的 ONNX 模型。下面是一个简单的示例,展示了如何在 C++ 中使用 ONNX 运行时库加载并推理 ONNX 模型:
#include <onnxruntime_cxx_api.h> int main() { // 创建 ONNX 运行时会话 Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"}; Ort::SessionOptions session_options; session_options.SetIntraOpNumThreads(1); Ort::Session session{env, "model.onnx", session_options}; // 获取模型输入和输出信息 Ort::AllocatorWithDefaultOptions allocator; size_t num_input_nodes = session.GetInputCount(); std::vector<const char*> input_node_names(num_input_nodes); std::vector<int64_t> input_node_dims; for (int i = 0; i < num_input_nodes; i++) { // 获取输入节点名称和维度 char* input_name = session.GetInputName(i, allocator); input_node_names[i] = input_name; Ort::TypeInfo type_info = session.GetInputTypeInfo(i); auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); input_node_dims = tensor_info.GetShape(); } // 创建输入张量 std::vector<float> input_tensor_values(10); std::vector<Ort::Value> input_tensors; input_tensors.push_back(Ort::Value::CreateTensor<float>(allocator.GetInfo(), input_tensor_values.data(), 10, input_node_dims.data(), 2)); // 进行推理 auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_node_names.data(), input_tensors.data(), num_input_nodes, nullptr, 0); // 获取输出结果 float* floatarr = output_tensors[0].GetTensorMutableData<float>(); std::cout << *floatarr << std::endl; return 0; }
-
-
Q:
- 在C++中用LibPytorch和使用onnxruntime各自的优缺点是什么?
-
A:
-
在 C++ 中使用 LibTorch 和使用 ONNX 运行时(onnxruntime)都可以实现 PyTorch 模型的推理。这两种方法各有优缺点。
LibTorch 是 PyTorch 的 C++ 前端,它提供了与 PyTorch 相同的 API,因此可以直接在 C++ 中使用 PyTorch 模型进行推理。这意味着您可以直接加载保存的 PyTorch 模型,而无需进行任何转换。此外,由于 LibTorch 是 PyTorch 的一部分,因此它能够支持 PyTorch 的所有功能和最新特性。
相比之下,ONNX 运行时是一个用于执行 ONNX(开放神经网络交换)模型的推理引擎。要在 C++ 中使用 ONNX 运行时进行推理,您需要先将 PyTorch 模型转换为 ONNX 格式。ONNX 运行时具有出色的性能和可移植性,可以在多种平台和设备上运行,包括云端、边缘设备、Web 和移动设备1。
总之,如果您希望直接在 C++ 中使用 PyTorch 模型,并且需要支持 PyTorch 的所有功能和最新特性,则可以考虑使用 LibTorch。如果您希望获得更好的性能和可移植性,并且不介意将模型转换为 ONNX 格式,则可以考虑使用 ONNX 运行时。希望这些信息能够帮助您了解这两种方法的优缺点。
-
-
Q:
- pytorch lightning模型可以转换成torch模型吗?
-
A:
-
是的,PyTorch Lightning 模型可以转换为普通的 PyTorch 模型。PyTorch Lightning 是一个用于构建、训练和调优 PyTorch 模型的高层库,它在 PyTorch 的基础上提供了更简洁、更易用的 API。由于 PyTorch Lightning 模型本质上仍然是 PyTorch 模型,因此您可以直接将其转换为普通的 PyTorch 模型。
要将 PyTorch Lightning 模型转换为普通的 PyTorch 模型,您只需要获取 PyTorch Lightning 模型的
model
属性即可。例如:import pytorch_lightning as pl import torch class MyLightningModel(pl.LightningModule): def __init__(self): super().__init__() self.l1 = torch.nn.Linear(28 * 28, 10) def forward(self, x): return torch.relu(self.l1(x.view(x.size(0), -1))) # 创建 PyTorch Lightning 模型 lightning_model = MyLightningModel() # 转换为普通的 PyTorch 模型 pytorch_model = lightning_model.model
在上面的示例中,我们首先定义了一个简单的 PyTorch Lightning 模型
MyLightningModel
,然后使用lightning_model.model
获取了对应的普通 PyTorch 模型。
感叹一下,在大模型的加持下,我们只要学会如何提问就好了,GPT真的会帮我们找到方法。
-
我的方案
根据GPT的提示,将PL转为正常的torch
其实就多了一句话:
# 加载PL NN模型
model = ABC(xxxx)
# 加载权重
model.load_state_dict(ckptfile)
model_torch = model.model
# torch.jit转换模型
f = torch.jit.script(model_torch)
# 保存为xxx.pt文件
f.save(pt_file)
搞定!
如果帮助到你,给个点赞+收藏+关注