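TensorFlow can be used by installing the tensorflow package with pip and calling its Python API, or through its C++ or C API for inference. For performance or business reasons, some users choose the C/C++ interfaces. For the C API, TensorFlow provides prebuilt headers and shared libraries (https://www.tensorflow.org/install/lang_c); the drawback is that the C++ API cannot be called through them, which is inconvenient. The C++ API, on the other hand, usually requires users to rebuild TensorFlow from source, which is time-consuming and laborious (see blog posts such as "Tensorflow C API 从训练到部署:使用 C API 进行预测和部署" on 技术刘 and "使用C++调用TensorFlow模型简单说明" on Dannyw's Blog).
If you write C++ code and link it against the so files in the directory where pip installed TensorFlow, you get the following error at runtime: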
E tensorflow/core/common_runtime/session.cc:67] Not found: No session factory registered for the given session options: {target: "" config: } Registered factories are {}.
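You will also find that none of TensorFlow's built-in ops are registered, and linking with -Wl,--whole-archive does not fix this.
So is it possible to use the so files and headers from the pip-installed tensorflow directly and run inference through the C++ API? The author found a way to do so and shares it below.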
Example inference code: main.cpp
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/env.h"
#include <iostream>
#include <string>
#include <utility>
#include <vector>
using namespace tensorflow;
#ifdef __cplusplus
extern "C" {
#endif
// instantiated in tensorflow _pywrap_tensorflow_internal.so
extern const char* TF_Version(void);
#ifdef __cplusplus
}
#endif
int main() {
  // Must be called so that the op and session-factory registrations are loaded.
  TF_Version();
  std::string model_path = "resnet_50.pb";
  tensorflow::GraphDef graphdef;
  tensorflow::Status status_load = ReadBinaryProto(tensorflow::Env::Default(), model_path, &graphdef);
  if (!status_load.ok()) {
    std::cout << "load graph failed: " << status_load.ToString() << std::endl;
    return -1;
  }
  tensorflow::SessionOptions options;
  tensorflow::Session* session;
  session = tensorflow::NewSession(options);
  if (session == nullptr) {
    std::cout << "create new session failed" << std::endl;
    return -1;
  }
  tensorflow::Status status;
  status = session->Extend(graphdef);
  if (!status.ok()) {
    std::cout << "session extend graph failed" << std::endl;
    return -1;
  }
  // Input tensor; its contents are left unfilled in this example.
  Tensor x(DT_FLOAT, TensorShape({1, 3, 224, 224}));
  std::vector<std::pair<std::string, tensorflow::Tensor>> input_tensors;
  input_tensors.push_back({"input", x});
  std::vector<std::string> output_names = {"resnet_model/stage_1/Relu_2"};
  std::vector<Tensor> outputs;
  TF_CHECK_OK(session->Run(input_tensors, output_names, {}, &outputs));
  // release session
  session->Close();
  delete session;
  session = nullptr;
  return 0;
}
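The example above runs the graph on a tensor whose contents are never filled in. As a minimal sketch of how real data might be prepared, the helper below converts an interleaved HWC uint8 image into a planar CHW float buffer that matches the {1, 3, 224, 224} shape. The function name HwcToChw and the division by 255 are assumptions made for illustration; the preprocessing your model actually expects may differ.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not part of TensorFlow): converts an interleaved
// HWC uint8 image into a planar CHW float buffer.
// Scaling to [0, 1] by dividing by 255 is an assumed normalization.
std::vector<float> HwcToChw(const std::vector<std::uint8_t>& hwc,
                            std::size_t height, std::size_t width,
                            std::size_t channels) {
  if (hwc.size() != height * width * channels) {
    throw std::invalid_argument("buffer size does not match H*W*C");
  }
  std::vector<float> chw(hwc.size());
  for (std::size_t h = 0; h < height; ++h) {
    for (std::size_t w = 0; w < width; ++w) {
      for (std::size_t c = 0; c < channels; ++c) {
        // Source index: pixels are stored one after another, channels interleaved.
        const std::size_t src = (h * width + w) * channels + c;
        // Destination index: one full plane per channel.
        const std::size_t dst = (c * height + h) * width + w;
        chw[dst] = static_cast<float>(hwc[src]) / 255.0f;
      }
    }
  }
  return chw;
}
```

With a 224x224 RGB image, the result has 3 * 224 * 224 elements and could then be copied into the input tensor, for example with std::copy(chw.begin(), chw.end(), x.flat<float>().data()) before calling session->Run.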
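The key here is the call to TF_Version(); (other functions may have a similar effect), which causes the symbols in the so to be loaded; without it they are not loaded. Readers are welcome to discuss the exact reason in the comments. The tf 2.x pip package already provides a declaration for this function, whereas 1.1x does not, so there it has to be declared manually.
CMake build configuration
The key point is to link against the Python so and the two TensorFlow so files.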
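One possible explanation, offered as an assumption rather than something the article confirms: session factories and ops register themselves through static objects in the library that defines them. If the executable references no symbol from that library, the linker may drop the dependency (for example when --as-needed is in effect). The library is then never loaded and its registrars never run. The toy model below, in which every class and function name is hypothetical and unrelated to TensorFlow's real implementation, shows how a registry stays empty until the code containing the registrar is actually brought in.

```cpp
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a factory registry (hypothetical, not TensorFlow's classes).
class FactoryRegistry {
 public:
  static FactoryRegistry& Instance() {
    static FactoryRegistry registry;
    return registry;
  }
  void Register(const std::string& name, std::function<int()> factory) {
    factories_[name] = std::move(factory);
  }
  // Returns the names of all registered factories in sorted order.
  std::vector<std::string> Names() const {
    std::vector<std::string> names;
    for (const auto& entry : factories_) names.push_back(entry.first);
    return names;
  }

 private:
  std::map<std::string, std::function<int()>> factories_;
};

// A registrar registers a factory as a side effect of being constructed.
struct Registrar {
  Registrar(const std::string& name, std::function<int()> factory) {
    FactoryRegistry::Instance().Register(name, std::move(factory));
  }
};

// Stands in for "the shared library was loaded": the registrar object is
// only constructed the first time this function is called. The name
// "DIRECT_SESSION" is used purely as an illustrative label.
void SimulateLibraryLoad() {
  static Registrar direct_session("DIRECT_SESSION", [] { return 0; });
}
```

Before SimulateLibraryLoad() is called, FactoryRegistry::Instance().Names() is empty, which mirrors the "Registered factories are {}" message; afterwards it contains one entry. In this reading, calling TF_Version() plays the role of SimulateLibraryLoad() by forcing a real reference into _pywrap_tensorflow_internal.so.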
cmake_minimum_required(VERSION 3.10)
project(tf_cpp_test LANGUAGES CXX)
add_compile_options(-fPIC)
# tf version >=1.15 use ABI=0
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_executable(
  main
  main.cpp
)
target_include_directories(
  main
  PUBLIC
  $ENV{TF_INCLUDE_PATH}
  $ENV{PYTHON_INCLUDE_PATH}
)
target_link_libraries(
  main
  PUBLIC
  $ENV{TF_SO_FILE}
  $ENV{TF_SO_PATH}/python/_pywrap_tensorflow_internal.so
  $ENV{PYTHON_SO_FILE}
)
TF_INCLUDE_PATH and the other variables used above can be obtained with a bash script:
#!/bin/bash
TOOL_SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
# Drop the leading "-I" / "-L" by slicing; strip("-I") would also remove those characters from the end of the path.
export TF_INCLUDE_PATH=$(python3 -c 'import tensorflow as tf; print(tf.sysconfig.get_compile_flags()[0][2:])')
export TF_SO_PATH=$(python3 -c 'import tensorflow as tf; print(tf.sysconfig.get_link_flags()[0][2:])')
export TF_SO_FILE=$(ls "$TF_SO_PATH"/libtensorflow_framework.* | head -1)
export PYTHON_INCLUDE_PATH=$(python3 -c 'import sysconfig; print(sysconfig.get_path("include"))')
export PYTHON_SO_PATH=$(python3 -c 'import sysconfig; print(sysconfig.get_path("stdlib"))')
# Quote the pattern so the shell does not expand it before find sees it.
export PYTHON_SO_FILE=$(find "$PYTHON_SO_PATH/../" -name 'libpython3*.so' | head -1)
mkdir -p "${TOOL_SCRIPT_DIR}/build"
cd "${TOOL_SCRIPT_DIR}/build"
cmake ..
make
The code above was tested with tf 1.15 + Python 3.7 (in a conda virtual environment).