为了训练好的模型方便在生产环境中使用,可用将python生成的模型转换为C++可使用的模型(.pt模型文件)。然后使用pytorch提供了C++库libtorch,加载该模型。
python模型转换为.pt格式可用使用两种方式:torch.jit.trace 和 torch.jit.script。
torch.jit.trace 和 torch.jit.script 个有优缺点:
torch.jit.trace 加载模型后随便构造一个输入使用模型推理一次,通过跟踪方式将转换模型,优点是简单方便,缺点是如果推理代码中有根据输入走不同分支的判断,该方法不适合。因为它只能导出本次执行过的代码分支。
torch.jit.script 只导入训练好的模型,不需要构造输入执行一遍模型推导。但是 torch.jit.script 脚本不支持python所有的操作,如果遇到问题需要根据错误提示修改python代码。
torch.jit.trace举例:
# model 为你的模型对象,加载了训练好的模型
model.load_state_dict("python-mode.pytorch")
# 使用torch.jit.trace来创建一个trace模型,它将只记录运行模型所需的操作
trace_model = torch.jit.trace(model, torch.rand(1, 1, 3, 256, 256))
# 保存模型
torch.jit.save(trace_model, 'jit-trace.model.pt')
torch.jit.script举例:
# model 为你的模型对象,加载了训练好的模型
model.load_state_dict("python-mode.pytorch")
# 使用torch.jit.script,转换模型
script_model = torch.jit.script(model)
# 保存模型
torch.jit.save(script_model, 'jit-script.model.pt')
torch.jit.script方法中遇到过的问题和处理方法(主要是script不支持的python操作):
1、不支持 python 的 functools.partial 函数
解决,需要使用 functools.partial 转换的函数修改为直接调用原函数的方式。
2、不支持对链表、字典的迭代操作,如 for i in range(...) ,可能会提示这的错误:Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals. Enumeration is supported, e.g. 'for index, v in enumerate(self): ...':
解决,根据提示,将循环该成这样:
for index, v in enumerate(你的链表或字典等)
3、有些python基础类型 script 不支持需要转换的,一般根据错误提示转换即可。
模型转换好后使用libtorch加载执行,C++实现代码:
#include <iostream>
#include <torch/script.h>
int main() {
// 加载导出的模型
torch::jit::script::Module module;
try {
module = torch::jit::load("model.pt");
}
catch (const c10::Error& e) {
std::cerr << "Error loading the model\n";
return -1;
}
// 输入张量,这里只是测试
torch::Tensor input = torch::ones({1, 1, 3, 255, 255});
// 运行模型
torch::Tensor output = module.forward({input}).toTensor();
std::cout << "Output: " << output << std::endl;
return 0;
}
编译环境配置等后续补充...