三种部署Pytorch模型到C++环境的方式
文章目录
前言
由于工作原因需要部署Pytorch模型到c++环境下,目前大概有三种方式。
1、pytorch转成onnx文件后,通过opencv读取。
2、pytroch转成onnx文件后,通过onnxruntime读取。
3、利用libtorch库,也就是pytorch的c++版。
一、pytorch2onnx
首先的将pytorch训练好的模型导出onnx文件。
安装所需包:
pip install onnx
pip install onnxruntime
from nets.deeplabv3 import deeplabv3 #这里导入自己的模型
import torch
import os
from PIL import Image
import numpy as np
import onnx
import onnxruntime
def preprocess_input(image):
image /= 255.0
return image
def cvtColor(image):
if len(np.shape(image)) == 3 and np.shape(image)[-2] == 3:
return image
else:
image = image.convert('RGB')
return image
# 检查输出
def check_onnx_output(filename, input_data, torch_output):
print("模型测试")
session = onnxruntime.InferenceSession(filename)
input_name = session.get_inputs()[0].name
result = session.run([], {
input_name: input_data.detach().cpu().numpy()})
for test_result, gold_result in zip(result, torch_output.values()):
np.testing.assert_almost_equal(
gold_result.cpu().numpy(), test_result, decimal=3,
)
return result
# 检查模型
def check_onnx_model(model, onnx_filename, input_image):
with torch.no_grad():
torch_out = {
"output": model(input_image)}
check_onnx_output(onnx_filename, input_image, torch_out)
print("模型输出一致")
onnx_model = onnx.load(onnx_filename)
onnx.checker.check_model(onnx_model)
print("模型测试成功")
return onnx_model
if __name__ == '__main__':
# 模型路径
model_path = 'net.pth'
onnx_path = os.path.split(model_path)[0] + '/'
device = 'cpu'
# 图片路径
VOCdevkit_path ='./1.jpg'
img = Image.open(VOCdevkit_path)
img = cvtColor(img)
img = np.expand_dims(np.transpose(preprocess_input(np.array(img, np.float32)), (2, 0, 1)), 0)
img = torch.from_numpy(img)
net = deeplabv3 ()
net.load_state_dict(torch.load(model_path, map_location=device), strict=True)
net = net.eval()
out = net(img)
print(out)
torch.onnx.export(net, img, onnx_path + "torch.onnx", verbose=True ,input_names=["input"], output_names=["output"], opset_version=11)
# traced_cpu = torch.jit.trace(net, img)
# torch.jit.save(traced_cpu, onnx_path + "cpu.pt")
# 检测导出的onnx模型是否完整,输出是否一致
onnx_name = onnx_path + "torch.onnx"
onnx_model = check_onnx_model(net, onnx_name, img)
二、三种部署的方式
1.opencv加载onnx
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <onnxruntime_cxx_api.h>
#include <fstream>
#include <iostream>
#include <cstdlib>
using namespace std;
int main()
{
String modelFile = "./torch.onnx";
String imageFile = "./1.jpg";
dnn::Net net = cv::dnn::readNetFromONNX(modelFile); //读取网络和参数
// step 1: Read an image in HWC BGR UINT8 format.
cv::Mat imageBGR = cv::imread(input_path