C++调用python训练的pytorch模型(三)----- 实战:封装pytorch模型

8 篇文章 1 订阅
5 篇文章 0 订阅

封装python 模型 SDK
准备好python api函数

python代码

# webcam_test.py
# Module-level handle to the detector built by load_model(); None until load_model() runs.
# (A bare `global` statement at module scope is a no-op and would leave the name undefined.)
g_model = None
def load_model(wkspace_dir, cfg_file, confidence_threshold=0.7, min_image_size=480):
    """Build the detector and store it in the module-level ``g_model``.

    Args:
        wkspace_dir: directory to switch into (``os.chdir``) before loading, so that
            a relative ``cfg_file`` is resolved from it.
        cfg_file: config file path, absolute or relative to ``wkspace_dir``; merged
            into ``cfg`` with ``merge_from_file``.
        confidence_threshold: forwarded to ``COCODemo`` (default 0.7, the value that
            used to be hard-coded).
        min_image_size: forwarded to ``COCODemo`` (default 480, the value that used
            to be hard-coded).

    Returns:
        The constructed ``COCODemo`` instance (the same object stored in ``g_model``).

    Side effects:
        Changes the process working directory to ``wkspace_dir`` and freezes ``cfg``.
    """
    global g_model
    print("wkspace_dir: %s" % wkspace_dir)
    print("cfg_file: %s" % cfg_file)
    os.chdir(wkspace_dir)
    # NOTE(review): cfg is frozen below; a second load_model() call will presumably
    # fail inside merge_from_file on the frozen config - confirm, and defrost/clone
    # cfg if reloading is ever needed.
    cfg.merge_from_file(cfg_file)
    cfg.freeze()
    g_model = COCODemo(
        cfg,
        confidence_threshold=confidence_threshold,
        show_mask_heatmaps=False,
        masks_per_dim=2,
        min_image_size=min_image_size,
    )
    return g_model

def forward(image):
    """Run the detector on one image file and describe its first detection.

    Args:
        image: path of the image file, read with ``cv2.imread``.

    Returns:
        ``"<label>,<score>,<x1>, <y1>, <x2>, <y2>"`` for the first prediction kept by
        ``select_top_predictions``: the integer category id, then the score and four
        box coordinates with two decimals. Returns ``""`` when the image cannot be
        read or nothing is detected. The C++ caller splits this string on ``,``.

    Raises:
        RuntimeError: if ``load_model`` has not been called yet.
    """
    # globals().get keeps this safe even if g_model was never assigned at module level.
    model = globals().get("g_model")
    if model is None:
        raise RuntimeError("load_model() must be called before forward()")
    print('image path is: %s' % image)
    img = cv2.imread(image)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        print('failed to read image: %s' % image)
        return ""
    predictions = model.select_top_predictions(model.compute_prediction(img))
    scores = predictions.get_field("scores").tolist()
    labels = predictions.get_field("labels").tolist()
    for box, score, label in zip(predictions.bbox, scores, labels):
        # Only the first detection is reported: the C++ side expects exactly 6 numbers.
        # The raw label id is used directly; mapping it to a name and back through
        # CATEGORIES.index() would give the wrong id if category names repeat.
        bbox_str = ', '.join('%.2f' % coord for coord in box)
        result = '%d,%.2f,%s' % (label, score, bbox_str)
        print(result)
        return result
    return ""

if __name__ == "__main__":
    # Manual smoke test. load_model() requires both arguments; calling it with none
    # raises TypeError before anything is loaded.
    load_model('/home/bob/wkspace/git/maskrcnn-benchmark/demo', 'r50_1204.yaml')
    forward('/home/bob/wkspace/git/maskrcnn-benchmark/demo/183101.jpg')
C++调用python api

hands_detect.h

#include <Python.h>
#include <memory>
#include <iostream>
#include <vector>

using namespace std;

namespace handsDetect{
    /**
     *  Initialize the embedded Python interpreter and load the model.
     *  @param wkspace_path : Python working directory; must be the maskrcnn-benchmark
     *                        demo directory that contains webcam_test.py.
     *  @param cfg_file     : path of the model config file; may be relative to the
     *                        demo directory.
     *  @return 0 on success, -1 on failure.
     */
int Initialize(const char* wkspace_path,const char* cfg_file);
    /**
     *  Run one forward pass of the network on an image file.
     *  Initialize() must have succeeded first.
     *  @param image_file : path of the input image.
     *  @param output     : receives the network output, 6 values in the order
     *                      [class, score, bbox[4]].
     *  @return 0 on success, -1 on failure.
     */
int  Forward(const char* image_file, std::vector<float> &output);
    /**
     *  Release the Python objects held by the SDK and shut down the interpreter.
     *  Returns nothing (the function is void).
     */
void Uninitialize();
}

hands_detect.cpp

#include "hands_detect.h"

PyObject * g_pModule = NULL;
PyObject * g_pFunc_init = NULL;
PyObject * g_pFunc_forward = NULL;

namespace handsDetect{
int Initialize(const char* wkspace_path,const char* cfg_file)
{
	if(wkspace_path == nullptr || cfg_file == nullptr){
		cout<<"please check wkspace path or cfg file path"<<endl;
		return -1;
	}
	Py_SetPythonHome(L"/home/bob/anaconda2/envs/benchmark_py36");
	Py_Initialize();
	if(!Py_IsInitialized())
	{
		cout<<"Error: python init failed!"<<endl;
		return -1;
	}
	else
	{
		cout<<"python init successful"<<endl;
	}
	PyRun_SimpleString("import sys");
	char buf[200];
	//char str[]="/home/bob/wkspace/git/maskrcnn-benchmark/demo/";
	sprintf(buf,"sys.path.append('%s')",wkspace_path);
	cout<< "wkspace_path:"<<buf<<endl;
	//PyRun_SimpleString("sys.path.append('/home/bob/wkspace/git/maskrcnn-benchmark/demo/')");
	PyRun_SimpleString(buf);

	g_pModule = PyImport_ImportModule("webcam_test");
	if(g_pModule == nullptr)
	{
		cout<<"Error: python module is null!"<<endl;
		return -1;
	}
	else
	{
		cout<<"python module is successful"<<endl;
	}
	
	g_pFunc_init = PyObject_GetAttrString(g_pModule,"load_model");
	if(g_pFunc_init == nullptr)
	{
		cout<<"Error: python pFunc_Initialize is null!"<<endl;
		return -1;
	}
	else
	{
		cout<<"python pFunc_Initialize is successful"<<endl;
	}
	g_pFunc_forward = PyObject_GetAttrString(g_pModule,"forward");
	if(g_pFunc_forward == nullptr)
	{
		cout<<"Error: python pFunc_forward is null!"<<endl;
		return -1;
	}
	else
	{
		cout<<"python pFunc_forward is successful"<<endl;
	}

	PyObject *pArgs = PyTuple_New(2);//函数调用的参数传递均是以元组的形式打包的,2表示参数个数
	PyTuple_SetItem(pArgs, 0, Py_BuildValue("s",wkspace_path));//0--序号
	PyTuple_SetItem(pArgs, 1, Py_BuildValue("s",cfg_file));//1--序号
	PyEval_CallObject(g_pFunc_init, pArgs);
	return 0;
}

void split_string(const string& s, vector<string>& v, const string& c)
{
    string::size_type pos1, pos2;
    pos2 = s.find(c);
    pos1 = 0;
    while(string::npos != pos2)
    {
        v.push_back(s.substr(pos1, pos2-pos1));

        pos1 = pos2 + c.size();
        pos2 = s.find(c, pos1);
    }
    if(pos1 != s.length())
        v.push_back(s.substr(pos1));
}

int string_to_data(char* src_str, vector<float> &output)
{
    //char *src_str= "1.1,2.2,3.3,4.44,5.5555,6.666";
    string s = src_str;
    vector<string> v;
    split_string(s, v,","); //可按多个字符来分隔;
    for(vector<string>::size_type i = 0; i != v.size(); ++i)
    {
        //cout <<"v:"<< v[i] << endl;
//        cout <<"o:"<< atof(v[i].c_str()) << endl;
    	output.push_back(atof(v[i].c_str()));
    }

    return 0;
}
int  Forward(const char* image_file, vector<float> &output)
{
	PyObject *args = Py_BuildValue("(s)",image_file);
	PyObject *pRet = PyObject_CallObject(g_pFunc_forward,args);
	if(pRet == nullptr)
	{
		cout<<"Error: python pFunc_forward pRet is null!"<<endl;
		return -1;
	}	
	else
	{
		cout<<"python pFunc_forward pRet is successful"<<endl;
	}
	char* res;
	PyArg_Parse(pRet,"s",&res);
	cout<<"res:"<<res<<endl;
	string_to_data(res, output);
	if (6 != output.size()) {
	     cout <<"Error: output size is "<<output.size()<< endl;
		return -1;
	}
	return 0;
}
void Uninitialize()
{
	Py_DECREF(g_pFunc_forward);
	Py_DECREF(g_pFunc_init);
	Py_DECREF(g_pModule);
	Py_Finalize();
}
}
生成 .so 动态库文件

bash命令行

# One-off build of libhandsdetect.so without make; libpython is linked as -lpython3.6m.
g++ hands_detect.cpp  -fPIC -shared -o libhandsdetect.so -std=c++11 \
-I/home/bob/anaconda2/envs/benchmark_py36/include/python3.6m/  \
-L/home/bob/anaconda2/envs/benchmark_py36/lib/ -lpython3.6m

makefile

# Builds libhandsdetect.so, the C++ wrapper around the Python detection model.

# Remove a half-written library if the compile/link step fails.
.DELETE_ON_ERROR:

# source object target
SRCS   := hands_detect.cpp
OBJS   := hands_detect.o
TARGET := libhandsdetect.so

# compile and lib parameter
# NOTE: CC holds a C++ compiler here; the name is kept so `make CC=...` keeps working.
CC      := g++
CFLAGS  := -Wall -g
LIBS    := -lpython3.6m
LIB_DIR := -L/home/bob/anaconda2/envs/benchmark_py36/lib/
DEFINES :=
INCLUDE := -I/home/bob/anaconda2/envs/benchmark_py36/include/python3.6m/

.PHONY: all clean

all: $(TARGET)

# Compile and link in one step. The library is rebuilt only when the source or
# its header (included with quotes from the same directory) changes.
# -fPIC/-shared/-std=c++11 are required, so they stay out of CFLAGS.
$(TARGET): $(SRCS) hands_detect.h
	$(CC) $(INCLUDE) $(DEFINES) $(CFLAGS) -fPIC -shared -std=c++11 $(SRCS) -o $@ $(LIB_DIR) $(LIBS)

clean:
	$(RM) *.o $(TARGET)
调用模型SDK
demo
#include "hands_detect.h"
#include <vector>
using namespace std;
namespace hd=handsDetect;
/**
 * Demo: load the model, run one detection and print the 6 output values.
 * @return 0 on success; 1 if initialization or the forward pass fails.
 */
int main()
{
	vector<float> output;
	if(hd::Initialize("/home/bob/wkspace/git/maskrcnn-benchmark/demo/","r50_1204.yaml") < 0){
		cout<< "hands detect initialize is failed"<<endl;
		// Without a loaded model there is nothing to run: Forward() would call into
		// Python objects that Initialize() never obtained.
		return 1;
	}

	int status = 0;
	if(hd::Forward("/home/bob/wkspace/git/maskrcnn-benchmark/demo/183101.jpg",output) < 0){
		cout<< "hands detect forward is failed"<<endl;
		status = 1;
	}
	cout<<"output: "<< output.size()<<endl;
	for(vector<float>::size_type i = 0; i != output.size(); ++i)
	{
		cout <<output[i] << " ";
	}
	cout<<endl;
	hd::Uninitialize();
	return status;
}
makefile
# Builds the `demo` executable, linked against lib/libhandsdetect.so and libpython.

# Remove a half-written object/executable if a step fails.
.DELETE_ON_ERROR:

# source object target
SOURCE := demo.cpp
OBJS   := demo.o
TARGET := demo

# compile and lib parameter
# NOTE: CC holds a C++ compiler here; the name is kept so `make CC=...` keeps working.
CC       := g++
LIBS     := -lhandsdetect
LDFLAGS  := -L. \
            -L./lib \
            -L/home/bob/anaconda2/envs/benchmark_py36/lib/
DEFINES  :=
INCLUDE  := -I/home/bob/anaconda2/envs/benchmark_py36/include/python3.6m/ \
            -I./lib
# Linker-only flags: runtime search paths baked into the executable. $$ORIGIN (the
# directory holding `demo`) is used instead of ./lib, which the loader resolves
# against the caller's current directory.
RPATH    := -Wl,-rpath=/home/bob/anaconda2/envs/benchmark_py36/lib/ -Wl,-rpath,'$$ORIGIN/lib'
CXXFLAGS := -g
# Emit demo.d so edits to included headers trigger a recompile.
DEPFLAGS := -MMD -MP

.PHONY: all clean

all: $(TARGET)

# link
$(TARGET): $(OBJS)
	$(CC) -o $@ $^ $(LDFLAGS) $(RPATH) $(LIBS)

# compile
$(OBJS): $(SOURCE)
	$(CC) $(INCLUDE) $(DEFINES) $(CXXFLAGS) $(DEPFLAGS) -c $< -o $@

clean:
	$(RM) *.o *.d $(TARGET)

# Generated header dependencies; included last so they cannot become the default goal.
-include $(OBJS:.o=.d)
编译并执行 demo（先编译 lib 目录下的 SDK 动态库，再编译并运行 demo）
# Build the SDK library in ./lib, then the demo that links against it, then run it.
# Each step runs only if the previous one succeeded, so a stale ./demo is never run.
make -C lib clean && make -C lib && \
make clean && make && \
./demo
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值