一、CMakeLists.txt
cmake_minimum_required(VERSION 3.15)
project(c_python_test VERSION 0.1.0 LANGUAGES CXX)

# Build as C++11 with any compiler. (The previous add_definitions(-std=c++11)
# only ran when the *C* compiler was GCC and leaked onto every target in this
# directory.) Release builds use CMake's default release flags.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

# Python installation to embed: the conda env used during development by default.
# Override at configure time with: cmake -DPython3_ROOT_DIR=/path/to/env ..
set(Python3_ROOT_DIR "/home/hc/miniconda3/envs/mjh" CACHE PATH
  "Root of the Python installation whose headers and libpython are used")
# Take the first matching installation found, searching Python3_ROOT_DIR first.
set(Python3_FIND_STRATEGY LOCATION)
find_package(Python3 3.8 REQUIRED COMPONENTS Development)
message(STATUS "Embedding Python ${Python3_VERSION}: ${Python3_LIBRARIES}")

add_executable(c_python_test c_call_python.cpp)

# Header search path for this project's own headers.
target_include_directories(c_python_test PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}")

# Python3::Python carries the Python include directory and the full path to
# libpython, so no link_directories()/-lpython3.8 pair is needed.
target_link_libraries(c_python_test PRIVATE Python3::Python)

# Warnings and debug info in every configuration (GCC/Clang flag syntax only).
target_compile_options(c_python_test PRIVATE
  "$<$<CXX_COMPILER_ID:GNU,Clang,AppleClang>:-Wall;-g>"
)

# Print the user/global C++ flags (the old message($(...)) printed a literal string).
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
二、Python代码（保存为 onnx_cpp.py：C++ 程序通过 PyImport_ImportModule("onnx_cpp") 导入该模块，并在程序运行时的当前目录中查找它）
import onnxruntime
import onnx
import time
import scipy.io as scio
import numpy as np
import time
NUM_CLASSES = 4
BATCH_SIZE = 64
onnx_model_path = "/home/hc/cz_ori/models/model_v1.onnx"

# Validate the model before building an inference session from it, loading it
# from the same path variable the session uses so the two cannot diverge.
model = onnx.load(onnx_model_path)
onnx.checker.check_model(model)

# Prefer CUDA; CPUExecutionProvider is listed explicitly so the fallback order
# is visible rather than implicit.
onnx_session = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
)
# Show which providers the session actually ended up using.
print(onnx_session.get_providers())
def test_open_set(PDW):
    """Classify one PDW sequence with the module-level ONNX session.

    Args:
        PDW: a flat sequence of 300 numbers (the C++ caller passes a tuple of
            floats). It is reshaped to (1, 100, 3): one batch of 100 rows with
            3 values each.

    Returns:
        The predicted class index as a Python int.

    Side effects:
        Prints the predicted label and the time spent in inference.

    Raises:
        ValueError: if PDW does not contain exactly 300 values (from reshape).
    """
    pdw = np.asarray(PDW, dtype=np.float32)
    test_images = pdw.reshape(1, 100, 3)
    start_time = time.perf_counter()  # monotonic clock for elapsed-time measurement
    input_name = onnx_session.get_inputs()[0].name
    # NOTE(review): the first model output is assumed to be (batch, num_classes)
    # logits, as implied by axis=1 below — confirm against the model.
    outs = onnx_session.run(None, {input_name: test_images})[0]
    # Numerically stable softmax: subtracting the per-row max keeps np.exp from
    # overflowing to inf, which would turn the probabilities into NaN.
    shifted = outs - np.max(outs, axis=1, keepdims=True)
    exp_outs = np.exp(shifted)
    softmax_probs = exp_outs / np.sum(exp_outs, axis=1, keepdims=True)
    preds_cls = np.argmax(softmax_probs, axis=1)
    predicted_cls = preds_cls.item()  # batch size is 1, so this yields a plain int
    print("预测标签:", predicted_cls)
    end_time = time.perf_counter()
    execution_time = end_time - start_time
    print("execution time:", execution_time)
    return predicted_cls
三、C++代码
//c_call_python.cpp
// Python.h must be included before any standard header. It is included by bare
// name so that the Python include directory configured in CMake is the one used.
#include <Python.h>
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;

namespace {

// Owns one strong reference to a PyObject and releases it (Py_XDECREF) when it
// goes out of scope, so every early return below is leak-free.
class OwnedRef {
public:
    explicit OwnedRef(PyObject* obj = nullptr) : obj_(obj) {}
    ~OwnedRef() { Py_XDECREF(obj_); }
    OwnedRef(const OwnedRef&) = delete;
    OwnedRef& operator=(const OwnedRef&) = delete;
    PyObject* get() const { return obj_; }
    explicit operator bool() const { return obj_ != nullptr; }
private:
    PyObject* obj_;
};

// Packs pdw into a tuple, imports onnx_cpp and calls onnx_cpp.test_open_set(tuple).
// Prints the integer result. Returns 0 on success, -1 on any failure (after
// printing the Python error).
// Requires an initialized interpreter. All references it creates are released
// before it returns, so the caller may call Py_Finalize() right after.
int run_test_open_set(const std::vector<float>& pdw) {
    // Python tuple holding the floats.
    OwnedRef pList(PyTuple_New(static_cast<Py_ssize_t>(pdw.size())));
    if (!pList) {
        PyErr_Print();
        printf("Failed to create PyTuple!\n");
        return -1;
    }
    // Fill the tuple with the data.
    for (size_t i = 0; i < pdw.size(); ++i) {
        PyObject *pValue = PyFloat_FromDouble(pdw[i]);
        if (!pValue) {
            PyErr_Print();
            printf("Failed to convert float to PyObject!\n");
            return -1;
        }
        // PyTuple_SetItem steals the reference to pValue, even when it fails.
        if (PyTuple_SetItem(pList.get(), static_cast<Py_ssize_t>(i), pValue) != 0) {
            PyErr_Print();
            printf("Failed to set item in PyTuple!\n");
            return -1;
        }
    }
    // Import the onnx_cpp module (onnx_cpp.py on sys.path).
    OwnedRef pModule(PyImport_ImportModule("onnx_cpp"));
    if (!pModule) {
        PyErr_Print();
        printf("import python failed1!!\n");
        return -1;
    }
    // Look up def test_open_set in onnx_cpp.py.
    OwnedRef pFunction(PyObject_GetAttrString(pModule.get(), "test_open_set"));
    if (!pFunction || !PyCallable_Check(pFunction.get())) {
        if (PyErr_Occurred()) {
            PyErr_Print();
        }
        printf("get python function failed!!!\n");
        return -1;
    }
    // PyTuple_Pack takes its own reference to pList; both are released on return.
    OwnedRef pArgs(PyTuple_Pack(1, pList.get()));
    if (!pArgs) {
        PyErr_Print();
        printf("Failed to pack arguments!\n");
        return -1;
    }
    OwnedRef pResult(PyObject_CallObject(pFunction.get(), pArgs.get()));
    if (!pResult) {
        PyErr_Print();
        printf("Function call failed!\n");
        return -1;
    }
    // PyLong_AsLong returns -1 with an exception set if the result is not an int.
    long result = PyLong_AsLong(pResult.get());
    if (result == -1 && PyErr_Occurred()) {
        PyErr_Print();
        printf("Result is not an integer!\n");
        return -1;
    }
    printf("Result is %ld\n", result);
    return 0;
}

}  // namespace

int main(){
    // Initialize the embedded Python interpreter.
    Py_Initialize();
    if(!Py_IsInitialized()){
        printf("python init fail\n");
        return -1;  // report failure to the shell like every other error path
    }
    // PyRun_SimpleString() executes a Python statement in __main__.
    // Smoke-test the interpreter with a print statement.
    PyRun_SimpleString("print('Hello Python!')\n");
    // Add the current directory to sys.path so onnx_cpp.py can be imported.
    // NOTE: './' is resolved against the process's working directory, not the
    // directory containing this executable.
    PyRun_SimpleString("import os,sys");
    PyRun_SimpleString("sys.path.append('./')");
    PyRun_SimpleString("print(os.getcwd())");  // show which directory that is
    // 100 rows of 3 values; the Python side reshapes this to (1, 100, 3).
    const std::vector<float> pdw = {
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
        84, 15, 3,  84, 27, 3,  84, 16, 3,  84, 35, 3,  84, 15, 3,
    };
    const int status = run_test_open_set(pdw);
    // Every Python reference owned by run_test_open_set has been released here.
    Py_Finalize();
    return status;
}
四、编译C++代码
# Out-of-source build. -p keeps a re-run from failing when build/ already exists.
mkdir -p build
cd build
cmake ..
make
# The program appends './' to sys.path, so run it from a directory that contains
# onnx_cpp.py (e.g. from the project root: ./build/c_python_test) — assumes
# onnx_cpp.py sits next to CMakeLists.txt; adjust if it lives elsewhere.