# 用C++ API训练tensorflow模型

1、利用python构建graph，代码如下：

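（更新 2018/1/5：`tf.initialize_variables` 和 `tf.all_variables` 已被弃用，可以改用 `init = tf.group(tf.global_variables_initializer(), name='init')`；注意此时 C++ 代码中初始化变量时运行的节点名也要从 "init_all_vars_op" 相应改为 "init"。）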

# Build the MLP training graph and serialize it to mlp.pb so the C++ program
# can load it. Written against the TF 1.x graph API (tf.placeholder /
# tf.Session); under TF 2.x replace the import with
# `import tensorflow.compat.v1 as tf` followed by `tf.disable_v2_behavior()`.
import tensorflow as tf

with tf.Session() as sess:
    # Inputs: a batch of 32-dim features and 8-dim targets. The names "x" and
    # "y" are the feed keys used by session->Run in the C++ code.
    x = tf.placeholder(tf.float32, [None, 32], name="x")
    y = tf.placeholder(tf.float32, [None, 8], name="y")

    # Hidden layer parameters (32 -> 16).
    w1 = tf.Variable(tf.truncated_normal([32, 16], stddev=0.1))
    b1 = tf.Variable(tf.constant(0.0, shape=[16]))

    # Output layer parameters (16 -> 8).
    w2 = tf.Variable(tf.truncated_normal([16, 8], stddev=0.1))
    b2 = tf.Variable(tf.constant(0.0, shape=[8]))

    # Hidden activation; the output layer below consumes it.
    a = tf.nn.tanh(tf.nn.bias_add(tf.matmul(x, w1), b1))
    y_out = tf.nn.tanh(tf.nn.bias_add(tf.matmul(a, w2), b2), name="y_out")
    cost = tf.reduce_sum(tf.square(y - y_out), name="cost")

    # One optimizer step; the C++ program runs the target named "train".
    train = tf.train.AdamOptimizer().minimize(cost, name="train")

    # Created AFTER the optimizer on purpose: Adam adds its own variables
    # (moment slots, beta powers) that must be initialized as well.
    # The name must match the init target run by the C++ code.
    init = tf.variables_initializer(tf.global_variables(),
                                    name="init_all_vars_op")

    tf.train.write_graph(sess.graph_def, "./", "mlp.pb", as_text=False)

2、编写 C++ 代码：读取 pb 文件，构造训练数据（示例中用随机数据代替真实数据），然后开始训练

-Cpp 代码
01
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"
using namespace tensorflow;
04

05
int main(int argc, char* argv[]) {
06

07
std::string graph_definition = "mlp.pb";
08
Session* session;
09
GraphDef graph_def;
10
SessionOptions opts;
11
std::vector<Tensor> outputs; // Store outputs
12
13

14
// Set GPU options
15
graph::SetDefaultDevice("/gpu:0", &graph_def);
16
opts.config.mutable_gpu_options()->set_per_process_gpu_memory_fraction(0.5);
17
opts.config.mutable_gpu_options()->set_allow_growth(true);
18

19
// create a new session
20
TF_CHECK_OK(NewSession(opts, &session));
21

22
23
TF_CHECK_OK(session->Create(graph_def));
24

25
// Initialize our variables
26
TF_CHECK_OK(session->Run({}, {}, {"init_all_vars_op"}, nullptr));
27

28
Tensor x(DT_FLOAT, TensorShape({100, 32}));
29
Tensor y(DT_FLOAT, TensorShape({100, 8}));
30
auto _XTensor = x.matrix<float>();
31
auto _YTensor = y.matrix<float>();
32

33
_XTensor.setRandom();
34
_YTensor.setRandom();
35

36
for (int i = 0; i < 10; ++i) {
37

38
TF_CHECK_OK(session->Run({{"x", x}, {"y", y}}, {"cost"}, {}, &outputs)); // Get cost
39
float cost = outputs[0].scalar<float>()(0);
40
std::cout << "Cost: " <<  cost << std::endl;
41
TF_CHECK_OK(session->Run({{"x", x}, {"y", y}}, {}, {"train"}, nullptr)); // Train
42
outputs.clear();
43
}
44

45

46
session->Close();
47
delete session;
48
return 0;
49
}

3、编译运行

3.1 使用libtensorflow.so库 + gcc编译

-Bash 代码
# Use g++ (not gcc) so the C++ standard library is linked automatically.
# The headers under tf/third_party/eigen3 only forward to the real Eigen
# sources via #include "unsupported/Eigen/CXX11/Tensor", so the Eigen root
# (the directory containing unsupported/Eigen/CXX11/Tensor) must also be on
# the include path. /usr/local/include/eigen3 is an assumed location --
# point it at the Eigen copy TensorFlow was built with.
g++ -std=c++11 \
    -I /usr/local/include/tf \
    -I /usr/local/include/eigen3 \
    -L /usr/local/lib \
    train.cc -ltensorflow

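编译时报错：

/usr/local/include/tf/third_party/eigen3/unsupported/Eigen/CXX11/Tensor:1:42: fatal error: unsupported/Eigen/CXX11/Tensor: No such file or directory

原因：tf/third_party/eigen3 下的这个头文件只是通过 #include "unsupported/Eigen/CXX11/Tensor" 转发到真正的 Eigen 源码。因此还需要用 -I 指定 Eigen 源码的根目录，即包含 unsupported/Eigen/CXX11/Tensor 的目录。另外，链接 C++ 程序应使用 g++ 而不是 gcc。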

3.2 用bazel的方式编译

-BUILD 文件（Bazel）
# tf_cc_binary is a macro defined in TensorFlow's build rules; it must be
# loaded before use in this BUILD file.
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")

# The C++ trainer. The target name is also the name of the produced
# executable under bazel-bin/<this package>/.
tf_cc_binary(
    name = "train_inCpp",
    srcs = ["train_inCpp.cc"],
    deps = [
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:client_session",
        "//tensorflow/core:tensorflow",
    ],
)

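用 `bazel build //tensorflow/heke/train_with_cpp/cpp:train_inCpp` 构建。完成后，会在 tensorflow-master/bazel-bin/tensorflow/heke/train_with_cpp/cpp 目录下生成可执行文件 train_inCpp（文件名即 BUILD 中的 name）。程序按相对路径 "mlp.pb" 读取图文件，所以先把 mlp.pb 放到当前工作目录，再直接运行 ./train_inCpp 即可。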

4、移植到其它机器上运行

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc

https://tebesu.github.io/posts/Training-a-TensorFlow-graph-in-C++-API

https://www.tensorflow.org/api_guides/cc/guide

https://matrices.io/training-a-deep-neural-network-using-only-tensorflow-c/

• 广告
• 抄袭
• 版权
• 政治
• 色情
• 无意义
• 其他

120