c++调用tensorflow训练好的.pb模型

c++调用tensorflow训练好的.pb模型

try {
        //设置python环境,如果电脑里有多个环境需要指定一下,只有一个可以不用设置,程序会到系统环境变量里面找
		Py_SetPythonHome(L"D:\\software\\Anaconda3\\envs\\tensorflow-1.12");
        //初始化python解释器
		Py_Initialize();
		PyEval_InitThreads();
		PyObject* pFunc = NULL;//python调用函数接口
		PyObject* pArg = PyTuple_New(9);//初始化要传递给python的参数个数
		PyObject* module = NULL;

		PyRun_SimpleString("import sys");
		PyRun_SimpleString("sys.path.append('./')");//这一步很重要,修改Python路径
		module = PyImport_ImportModule("train");//myModel:Python文件名  
		if (!module) {
			printf("cannot open module!");
			Py_Finalize();
			//return;
		}
		pFunc = PyObject_GetAttrString(module, "train");//test_one_image:Python文件中的函数名  
		if (!pFunc) {
			printf("cannot open FUNC!");
			Py_Finalize();
			//return;
		}

		CString data_path = param.data_path.c_str();
		CString model_dir = param.save_model_path.c_str();
		int blanced_num = param.balance_num;
		int epoch_num = param.epoch_num;
		int batch_size = param.batch_size;
		int class_num = param.class_num;
		model_type type = param.type;
		int is_socket = 0;
		 
		string task;
		if (type == 0)
			task = "classification";
		else if (type == 1)
			task = "segmentation";
		else if (type == 2)
			task = "detection";


		int list_num = param.class_weight.size();
		PyObject *PyList = PyList_New(list_num);
		for (int i = 0; i < list_num; i++)
			PyList_SetItem(PyList, i, Py_BuildValue("f", param.class_weight[i]));
	
		PyTuple_SetItem(pArg, 0, Py_BuildValue("s", param.save_model_path.c_str()));
		PyTuple_SetItem(pArg, 1, Py_BuildValue("i", batch_size));
		PyTuple_SetItem(pArg, 2, Py_BuildValue("i", epoch_num));
		PyTuple_SetItem(pArg, 3, PyList);
		PyTuple_SetItem(pArg, 4, Py_BuildValue("i", class_num));
		PyTuple_SetItem(pArg, 5, Py_BuildValue("i", blanced_num));
		PyTuple_SetItem(pArg, 6, Py_BuildValue("s", param.data_path.c_str()));
		PyTuple_SetItem(pArg, 7, Py_BuildValue("s", task.c_str()));
		PyTuple_SetItem(pArg, 8, Py_BuildValue("i", is_socket));
		
		if (module != NULL) 
		{
			PyGILState_STATE gstate;
			gstate = PyGILState_Ensure();
			PyEval_CallObject(pFunc, pArg);//调用函数,传递参数
			//PyGILState_Release(gstate);
		}
	}
	catch (exception& e)
	{
		cout << "Standard exception: " << e.what() << endl;
	}


	

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值