概念区分
- jit.script 与 jit.trace:[PyTorch] jit.script 与 jit.trace
简单区分:torch.jit.script 转化不了的模型可以试试torch.jit.trace
推荐使用的是jit.trace
- yolov5中的导出模型使用的便是torch.jit.trace
https://github.com/ultralytics/yolov5/blob/8b18b66304317276f4bfc7cc7741bd535dc5fa7a/export.py
img = torch.zeros(batch_size, 3, *img_size).to(device) # image size(1,3,320,192) iDetection
def export_torchscript(model, img, file, optimize):
# TorchScript model export
prefix = colorstr('TorchScript:')
try:
print(f'\n{prefix} starting export with torch {torch.__version__}...')
f = file.with_suffix('.torchscript.pt')
# 这里导出模型
ts = torch.jit.trace(model, img, strict=False)
(optimize_for_mobile(ts) if optimize else ts).save(f)
print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return ts
except Exception as e:
print(f'{prefix} export failure: {e}')
c++中如何调用?
cpp文件中,<torch/script.h>头文件包含所有LibTorch库文件,main函数接收命令行参数,使用torch::jit::script::Module创建module对象用以加载模型,使用torch::jit::load函数加载命令行参数指定的模型,加载失败时输出error loading the module,成功则输出ok
#include <torch/script.h>
torch::jit::script::Module module;
try {
module = torch::jit::load(model_path);
}