pytorch-lightning训练的结果保存为xxx.pt

pytorch-lightning训练的结果保存为xxx.pt

问题描述:

模型是基于pytorch-lightning写的,但是由于pl的forward函数是forward(self, *args, **kwargs),直接用torch.jit.script(model)来保存,会报错。

保存为.pt是为了在C++代码中执行推理,关于这个内容,很多人写过教程了。但是目前关于我遇到的这个问题,还没有人给出系统的回答。

问题代码:

直接加载模型,然后保存:

import torch
# 加载PL NN模型
model = ABC(xxxx)
# 加载权重
model.load_state_dict(ckptfile)
# torch.jit转换模型
f = torch.jit.script(model)
# 保存为xxx.pt文件
f.save(pt_file)

报错如下:

torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
File "xxx/site-packages/pytorch_lightning/core/lightning.py", line 458
def forward(self, *args, **kwargs):
~~~~~~~ <--- HERE
r"""
Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define

解决方案

先看看GPT怎么说的,以下内容来自 bing,我的方案在最后。

GPT方案

  • Q:

    • 报错,torch.jit在保存模型时,要求forward()函数中不能有变量,但是pytorch lightning的forward()函数采样了*args, **kwargs的输入,与torch.jit的要求不符。
  • A:

    • 您好,您提到的问题是一个常见的问题。在使用 PyTorch Lightning 时,forward 函数通常会采用 *args**kwargs 作为输入参数,以便更灵活地处理不同类型的输入。然而,当使用 torch.jit.scripttorch.jit.trace 将模型转换为 TorchScript 格式时,这些函数可能无法处理这种灵活性。

      一种解决方案是在将模型转换为 TorchScript 格式之前,重写 forward 函数,使其具有固定数量的输入参数。例如,如果您的模型始终接受两个输入张量 xy,则可以将 forward 函数重写为:

      def forward(self, x, y):
          # 模型前向传播逻辑
      

      这样,在使用 torch.jit.scripttorch.jit.trace 将模型转换为 TorchScript 格式时,就不会出现问题了。

  • Q:

    • 是否有其他方案?
  • A:

    • 除了使用 torch.jit.scripttorch.jit.trace 将 PyTorch Lightning 模型转换为 TorchScript 格式外,您还可以考虑使用 ONNX(开放神经网络交换)格式。ONNX 是一种用于表示深度学习模型的开放标准,它允许您在不同的深度学习框架之间转换模型1

      要将 PyTorch Lightning 模型转换为 ONNX 格式,您可以使用 PyTorch 的 torch.onnx.export 函数。下面是一个简单的示例,展示了如何将 PyTorch Lightning 模型转换为 ONNX 格式:

      import torch
      from pytorch_lightning import LightningModule
      
      class MyModel(LightningModule):
          def __init__(self):
              super().__init__()
              self.layer = torch.nn.Linear(10, 1)
      
          def forward(self, x):
              return self.layer(x)
      
      model = MyModel()
      input_sample = torch.randn(1, 10)
      torch.onnx.export(model, input_sample, 'model.onnx')
      

      在上面的示例中,我们首先定义了一个简单的 PyTorch Lightning 模型 MyModel,然后使用 torch.onnx.export 函数将其转换为 ONNX 格式并保存到文件 model.onnx 中。

      在 C++ 中,您可以使用 ONNX 运行时库来加载并推理保存的 ONNX 模型。下面是一个简单的示例,展示了如何在 C++ 中使用 ONNX 运行时库加载并推理 ONNX 模型:

      #include <onnxruntime_cxx_api.h>
      
      int main() {
          // 创建 ONNX 运行时会话
          Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};
          Ort::SessionOptions session_options;
          session_options.SetIntraOpNumThreads(1);
          Ort::Session session{env, "model.onnx", session_options};
      
          // 获取模型输入和输出信息
          Ort::AllocatorWithDefaultOptions allocator;
          size_t num_input_nodes = session.GetInputCount();
          std::vector<const char*> input_node_names(num_input_nodes);
          std::vector<int64_t> input_node_dims;
          for (int i = 0; i < num_input_nodes; i++) {
              // 获取输入节点名称和维度
              char* input_name = session.GetInputName(i, allocator);
              input_node_names[i] = input_name;
              Ort::TypeInfo type_info = session.GetInputTypeInfo(i);
              auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
              input_node_dims = tensor_info.GetShape();
          }
      
          // 创建输入张量
          std::vector<float> input_tensor_values(10);
          std::vector<Ort::Value> input_tensors;
          input_tensors.push_back(Ort::Value::CreateTensor<float>(allocator.GetInfo(), input_tensor_values.data(), 10, input_node_dims.data(), 2));
      
          // 进行推理
          auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_node_names.data(), input_tensors.data(), num_input_nodes, nullptr, 0);
      
          // 获取输出结果
          float* floatarr = output_tensors[0].GetTensorMutableData<float>();
          std::cout << *floatarr << std::endl;
      
          return 0;
      }
      
  • Q:

    • 在C++中用LibPytorch和使用onnxruntime各自的优缺点是什么?
  • A:

  • Q:

    • pytorch lightning模型可以转换成torch模型吗?
  • A:

    • 是的,PyTorch Lightning 模型可以转换为普通的 PyTorch 模型。PyTorch Lightning 是一个用于构建、训练和调优 PyTorch 模型的高层库,它在 PyTorch 的基础上提供了更简洁、更易用的 API。由于 PyTorch Lightning 模型本质上仍然是 PyTorch 模型,因此您可以直接将其转换为普通的 PyTorch 模型。

      要将 PyTorch Lightning 模型转换为普通的 PyTorch 模型,您只需要获取 PyTorch Lightning 模型的 model 属性即可。例如:

      import pytorch_lightning as pl
      import torch
      
      class MyLightningModel(pl.LightningModule):
          def __init__(self):
              super().__init__()
              self.l1 = torch.nn.Linear(28 * 28, 10)
      
          def forward(self, x):
              return torch.relu(self.l1(x.view(x.size(0), -1)))
      
      # 创建 PyTorch Lightning 模型
      lightning_model = MyLightningModel()
      
      # 转换为普通的 PyTorch 模型
      pytorch_model = lightning_model.model
      

      在上面的示例中,我们首先定义了一个简单的 PyTorch Lightning 模型 MyLightningModel,然后使用 lightning_model.model 获取了对应的普通 PyTorch 模型。

    感叹一下,在大模型的加持下,我们只要学会如何提问就好了,GPT真的会帮我们找到方法。

我的方案

根据GPT的提示,将PL转为正常的torch

其实就多了一句话:

# 加载PL NN模型
model = ABC(xxxx)
# 加载权重
model.load_state_dict(ckptfile)
model_torch = model.model
# torch.jit转换模型
f = torch.jit.script(model_torch)
# 保存为xxx.pt文件
f.save(pt_file)

搞定!
如果帮助到你,给个点赞+收藏+关注

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值