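Loading a TensorFlow graph with the C++ API
In the TensorFlow repo, the C++ tutorials are far less detailed than the Python ones. This post shows how to load a pretrained graph from C++, so that it can be used on its own or embedded in another app.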
Requirements
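Install Bazel: TensorFlow is built with Bazel, so we also need Bazel to compile any other code that depends on TensorFlow. To learn more about Bazel, see my two other posts: "Bazel入门:编译C++项目" (Getting started with Bazel: building a C++ project) and "Bazel入门2:C++编译常见用例" (Getting started with Bazel 2: common C++ build use cases).
Clone the TensorFlow repo: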
git clone --recursive https://github.com/tensorflow/tensorflow
Building the graph
First we create a TensorFlow graph and save it as a protobuf for later use.
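# Uses the TensorFlow 1.x API (tf.Session).
import tensorflow as tf

with tf.Session() as sess:
    # Two scalar variables and their product; the names are what the
    # C++ loader uses later to feed inputs and fetch the output.
    a = tf.Variable(5.0, name='a')
    b = tf.Variable(6.0, name='b')
    c = tf.multiply(a, b, name="c")

    sess.run(tf.global_variables_initializer())

    print(a.eval())  # 5.0
    print(b.eval())  # 6.0
    print(c.eval())  # 30.0

    # Write the graph definition as a binary protobuf to models/graph.pb
    tf.train.write_graph(sess.graph_def, 'models/', 'graph.pb', as_text=False)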
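Building the binary
Let's create a directory named loader under tensorflow/tensorflow, i.e. tensorflow/tensorflow/loader, to hold the code that loads the graph we just created.
Inside loader/ we create a new file called loader.cc. In loader.cc we need to do the following: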
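- Initialize a TensorFlow session
- Load the graph we created earlier
- Add the graph to the session
- Set up the inputs and outputs
- Run the graph to get the outputs
- Read the values from the outputs
- Close the session and release its resources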
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"

using namespace tensorflow;

int main(int argc, char* argv[]) {
  // Initialize a tensorflow session
  Session* session;
  Status status = NewSession(SessionOptions(), &session);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Read in the protobuf graph we exported
  // (The path seems to be relative to the cwd. Keep this in mind
  // when using `bazel run` since the cwd isn't where you call
  // `bazel run` but from inside a temp folder.)
  GraphDef graph_def;
  status = ReadBinaryProto(Env::Default(), "models/graph.pb", &graph_def);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Add the graph to the session
  status = session->Create(graph_def);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Setup inputs and outputs:
  // This session never runs the graph's initializer op, so the variables
  // hold no values here. We feed "a" and "b" directly instead.
  Tensor a(DT_FLOAT, TensorShape());
  a.scalar<float>()() = 3.0;

  Tensor b(DT_FLOAT, TensorShape());
  b.scalar<float>()() = 2.0;

  std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
    { "a", a },
    { "b", b },
  };

  // The session will initialize the outputs
  std::vector<tensorflow::Tensor> outputs;

  // Run the session, evaluating our "c" operation from the graph
  status = session->Run(inputs, {"c"}, {}, &outputs);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  // Grab the first output (we only evaluated one graph node: "c")
  // and convert the node to a scalar representation.
  auto output_c = outputs[0].scalar<float>();

  // (There are similar methods for vectors and matrices here:
  // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/tensor.h)

  // Print the results (we fed a = 3 and b = 2, so c = 6)
  std::cout << outputs[0].DebugString() << "\n"; // Tensor<type: float shape: [] values: 6>
  std::cout << output_c() << "\n"; // 6

  // Free any resources used by the session
  session->Close();
  delete session;
  return 0;
}
Next we need a BUILD file for our project, which tells Bazel what to compile. In the BUILD file we define a cc_binary, which produces an executable.
cc_binary(
    name = "loader",
    srcs = ["loader.cc"],
    deps = [
        "//tensorflow/core:tensorflow",
    ],
)
The final file layout is:
- tensorflow/tensorflow/loader/
- tensorflow/tensorflow/loader/loader.cc
- tensorflow/tensorflow/loader/BUILD
Building and running
- From the root of the TensorFlow repo, run ./configure
- From the tensorflow/tensorflow/loader directory, run bazel build :loader
  - If the build fails with a long list of undefined reference to ... errors, try building with bazel build --config=monolithic :loader instead; see https://github.com/tensorflow/tensorflow/issues/13267
- From the root of the TensorFlow repo, cd into bazel-bin/tensorflow/loader
- Copy the graph protobuf to models/graph.pb
- Run ./loader and get the output!
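The graph has to be copied next to the binary because loader.cc opens the relative path models/graph.pb, which is resolved against the current working directory. One way around this might be to accept the path as a command-line argument. The sketch below assumes a hypothetical helper named ResolveGraphPath; it is not part of the original loader.cc or of TensorFlow.

```cpp
#include <string>

// Hypothetical helper: returns the first command-line argument as the graph
// path when one is given and non-empty, otherwise falls back to the relative
// default path used in loader.cc.
std::string ResolveGraphPath(int argc, const char* const argv[]) {
  const std::string kDefaultPath = "models/graph.pb";
  if (argc > 1 && argv[1] != nullptr && argv[1][0] != '\0') {
    return argv[1];
  }
  return kDefaultPath;
}
```

In loader.cc, the result would replace the hard-coded string passed to ReadBinaryProto, so the binary could be started as `./loader /absolute/path/to/graph.pb` from any directory.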