c++通过嵌入式Python C API调用tensorflow训练脚本(train.py中的train函数)
// Runs the Python training entry point train.train(...) from train.py through the
// embedded CPython interpreter, forwarding the training settings held in `param`.
// Failures (module/function not found, argument conversion errors, Python
// exceptions raised during training) are printed; nothing is thrown to the caller.
try {
    // Select the Python installation to embed. Only needed when several Python
    // environments are installed; otherwise the interpreter is located through the
    // system environment variables.
    // NOTE(review): machine-specific path; consider making it configurable.
    Py_SetPythonHome(L"D:\\software\\Anaconda3\\envs\\tensorflow-1.12");
    // Start the interpreter (a no-op when it is already running).
    Py_Initialize();
    // Creates the GIL on Python < 3.7; later versions do this in Py_Initialize.
    PyEval_InitThreads();

    // Make modules in the current working directory (train.py) importable.
    PyRun_SimpleString("import sys");
    PyRun_SimpleString("sys.path.append('./')");

    // "train": Python file name (train.py) and the function inside it.
    // Both calls return new references (or NULL with an exception set).
    PyObject* module = PyImport_ImportModule("train");
    PyObject* pFunc = module ? PyObject_GetAttrString(module, "train") : NULL;

    if (!module) {
        printf("cannot open module!\n");
        if (PyErr_Occurred())
            PyErr_Print();  // shows why the import failed
        Py_Finalize();
    } else if (!pFunc || !PyCallable_Check(pFunc)) {
        printf("cannot open FUNC!\n");
        if (PyErr_Occurred())
            PyErr_Print();
        Py_XDECREF(pFunc);
        Py_DECREF(module);
        Py_Finalize();
    } else {
        // Training settings forwarded to train.train(); int locals keep the
        // varargs passed to Py_BuildValue("i") exactly of type int.
        int blanced_num = param.balance_num;
        int epoch_num = param.epoch_num;
        int batch_size = param.batch_size;
        int class_num = param.class_num;
        model_type type = param.type;
        int is_socket = 0;

        // Map the model type to the task name passed to Python.
        // Any other value leaves the task empty.
        std::string task;
        if (type == 0)
            task = "classification";
        else if (type == 1)
            task = "segmentation";
        else if (type == 2)
            task = "detection";

        // Per-class weights as a Python list of floats. PyFloat_FromDouble
        // converts explicitly, whatever arithmetic type class_weight holds.
        const size_t list_num = param.class_weight.size();
        PyObject* weight_list = PyList_New(static_cast<Py_ssize_t>(list_num));
        for (size_t i = 0; weight_list && i < list_num; i++) {
            PyObject* weight = PyFloat_FromDouble(param.class_weight[i]);
            if (!weight) {
                Py_CLEAR(weight_list);  // exception is set; reported below
                break;
            }
            PyList_SetItem(weight_list, static_cast<Py_ssize_t>(i), weight);  // steals weight
        }

        // Positional arguments, in this order:
        // (save_model_path, batch_size, epoch_num, class_weight list, class_num,
        //  balance_num, data_path, task, is_socket).
        // Py_BuildValue returns NULL with an exception set if any conversion fails
        // (e.g. a path that is not valid UTF-8) or if weight_list is NULL.
        PyObject* pArg = Py_BuildValue("(siiOiissi)",
                                       param.save_model_path.c_str(),
                                       batch_size, epoch_num, weight_list,
                                       class_num, blanced_num,
                                       param.data_path.c_str(),
                                       task.c_str(), is_socket);
        Py_XDECREF(weight_list);  // "O" took its own reference for pArg

        if (!pArg) {
            printf("cannot build arguments!\n");
            if (PyErr_Occurred())
                PyErr_Print();
        } else {
            PyGILState_STATE gstate = PyGILState_Ensure();
            // Call train.train(*pArg). PyObject_CallObject replaces the deprecated
            // PyEval_CallObject with the same semantics.
            PyObject* result = PyObject_CallObject(pFunc, pArg);
            if (!result && PyErr_Occurred())
                PyErr_Print();  // Python traceback of a failed training run
            Py_XDECREF(result);
            Py_DECREF(pArg);
            // Balance the Ensure above.
            PyGILState_Release(gstate);
        }
        Py_DECREF(pFunc);
        Py_DECREF(module);
        // NOTE(review): the interpreter is deliberately not finalized on success
        // (unchanged behavior); the CPython docs warn that some extension modules
        // do not survive Py_Finalize/Py_Initialize cycles.
    }
}
catch (const std::exception& e)
{
    // Only C++ exceptions arrive here; Python errors are reported via PyErr_Print.
    std::cout << "Standard exception: " << e.what() << std::endl;
}