用C++ API训练tensorflow模型

在前面的博客中,已经从源码安装了tensorflow,能够成功编译c++的代码,那么就可以编写c++的代码编写tensorflow的模型,并训练模型。这里可以参考https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc#L49完成利用c++代码完成模型的构建和训练。但按照google上的说法,Auto-differentiation(自动微分,自动求导)功能不够完善,很多在c++的API中还没有集成(https://github.com/tensorflow/tensorflow/issues/4130)。


找到另外一篇博客(https://tebesu.github.io/posts/Training-a-TensorFlow-graph-in-C++-API)这里介绍另外一种用c++ API训练模型的方式。首先也是需要用python撰写网络结构,但是每个节点都需要命名,在c++中直接运行类似sess.run(节点名)的代码即可。


1、利用python构建graph,代码如下:


在Python的代码中,运行时,会提醒tf.initialize_variables这个函数已经被放弃了,可以改用tf.global_variables_initializer(),但是我还没找到如何在c++调用这个初始化(不能给这个操作命名,添加name),并且继续用tf.initialize_variables,可以正常运行。故目前继续采用这种方式。


(更新2018/1/5, 可以用init = tf.group(tf.global_variables_initializer(), name = 'init') 来解决)




import tensorflow as tf
-Python 代码
01
with tf.Session() as sess:
02
    x = tf.placeholder(tf.float32, [None, 32], name="x")
03
    y = tf.placeholder(tf.float32, [None, 8], name="y")
04
 
05
    w1 = tf.Variable(tf.truncated_normal([32, 16], stddev=0.1))
06
    b1 = tf.Variable(tf.constant(0.0, shape=[16]))
07
 
08
    w2 = tf.Variable(tf.truncated_normal([16, 8], stddev=0.1))
09
    b2 = tf.Variable(tf.constant(0.0, shape=[8]))
10
 
11
    a = tf.nn.tanh(tf.nn.bias_add(tf.matmul(x, w1), b1))
12
    y_out = tf.nn.tanh(tf.nn.bias_add(tf.matmul(a, w2), b2), name="y_out")
13
    cost = tf.reduce_sum(tf.square(y-y_out), name="cost")
14
    optimizer = tf.train.AdamOptimizer().minimize(cost, name="train")
15
 
16
    init = tf.initialize_variables(tf.all_variables(), name='init_all_vars_op')
17
    tf.train.write_graph(sess.graph_def,
18
                         './',
19
                         'mlp.pb', as_text=False)


2、编写c++代码读取pb文件,并读取数据,开始训练


代码如下:
-Cpp 代码
01
#include "tensorflow/core/public/session.h"
02
#include "tensorflow/core/graph/default_device.h"
03
using namespace tensorflow;
04
 
05
int main(int argc, char* argv[]) {
06
 
07
    std::string graph_definition = "mlp.pb";
08
    Session* session;
09
    GraphDef graph_def;
10
    SessionOptions opts;
11
    std::vector<Tensor> outputs; // Store outputs
12
    TF_CHECK_OK(ReadBinaryProto(Env::Default(), graph_definition, &graph_def));
13
 
14
    // Set GPU options
15
    graph::SetDefaultDevice("/gpu:0", &graph_def);
16
    opts.config.mutable_gpu_options()->set_per_process_gpu_memory_fraction(0.5);
17
    opts.config.mutable_gpu_options()->set_allow_growth(true);
18
 
19
    // create a new session
20
    TF_CHECK_OK(NewSession(opts, &session));
21
 
22
    // Load graph into session
23
    TF_CHECK_OK(session->Create(graph_def));
24
 
25
    // Initialize our variables
26
    TF_CHECK_OK(session->Run({}, {}, {"init_all_vars_op"}, nullptr));
27
 
28
    Tensor x(DT_FLOAT, TensorShape({100, 32}));
29
    Tensor y(DT_FLOAT, TensorShape({100, 8}));
30
    auto _XTensor = x.matrix<float>();
31
    auto _YTensor = y.matrix<float>();
32
 
33
    _XTensor.setRandom();
34
    _YTensor.setRandom();
35
 
36
    for (int i = 0; i < 10; ++i) {
37
 
38
        TF_CHECK_OK(session->Run({{"x", x}, {"y", y}}, {"cost"}, {}, &outputs)); // Get cost
39
        float cost = outputs[0].scalar<float>()(0);
40
        std::cout << "Cost: " <<  cost << std::endl;
41
        TF_CHECK_OK(session->Run({{"x", x}, {"y", y}}, {}, {"train"}, nullptr)); // Train
42
        outputs.clear();
43
    }
44
 
45
 
46
    session->Close();
47
    delete session;
48
    return 0;
49
}


3、编译运行


编译运行c++代码的时候,跟前面博客介绍的一样,有两种方式,一是直接用bazel编译,二是用so库+gcc的方式编译。


3.1 使用libtensorflow.so库 + gcc编译
在某个文件下,比如/home/heke/test/train_with_cpp下,保存步骤1的代码为build_graph.py,运行python代码,得到mlp.pb文件。
保存步骤2的代码为train.cc


输入编译的命令
-Bash 代码
1
gcc -std=c++11 -I /usr/local/include/tf -L /usr/local/lib train.cc -ltensorflow
但出现了如下问题(暂未解决):
/usr/local/include/tf/third_party/eigen3/unsupported/Eigen/CXX11/Tensor:1:42: fatal error: unsupported/Eigen/CXX11/Tensor: No such file or directory




3.2 用bazel的方式编译


在tensorflow-master/tensorflow/heke目录下,新建train_with_cpp目录,再在train_with_cpp目录下新建cpp和py两个文件夹。
保存步骤1的代码到py文件夹下,运行build_graph.py,得到mlp.pb文件
保存步骤2的代码到cpp文件夹下,得到train.cc文件
新建BUILD文件,粘贴如下代码:
-Bash 代码
01
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
02
 
03
tf_cc_binary(
04
    name = "train_inCpp",
05
    srcs = ["train_inCpp.cc"],
06
    deps = [
07
        "//tensorflow/cc:cc_ops",
08
        "//tensorflow/cc:client_session",
09
        "//tensorflow/core:tensorflow",
10
    ],
11
)


然后在cpp目录下,运行 bazel build :train


build完成后,就会在tensorflow-master/bazel-bin/tensorflow/heke/train_with_cpp/cpp目录下生成可执行文件,直接运行即可 ./train
出现类似下面的cost






4、移植到其它机器上运行


把cpp整个目录复制到其它机器上,并把libtensorflow_cc.so和libtensorflow.so、libtensorflow_framework.so放到/usr/local/lib下,就可以运行。


参考博客:


https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc


https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f


https://tebesu.github.io/posts/Training-a-TensorFlow-graph-in-C++-API


https://www.tensorflow.org/api_guides/cc/guide


https://matrices.io/training-a-deep-neural-network-using-only-tensorflow-c/
要部署 TensorFlow 模型,可以按照以下步骤进行操作: 1. 准备模型:首先,需要训练和保存 TensorFlow 模型。可以使用 TensorFlow 提供的高级 API,如 Keras,或使用原生 TensorFlow API 进行模型训练。在训练完成后,保存模型的权重和结构。 2. 导出模型:将模型导出为 TensorFlow 支持的格式,如 SavedModel 或 TensorFlow 格式(.pb)。这样做可以确保在部署过程中可以轻松加载模型。导出模型时,记得保存模型的元数据和签名。 3. 安装 TensorFlow 和相关库:在部署 TensorFlow 模型之前,需要在目标环境中安装 TensorFlow 和其他必要的软件库。可以使用 pip 或 conda 进行安装,并确保使用与训练时相同的 TensorFlow 版本。 4. 加载和推理:在部署环境中,导入 TensorFlow 和相关库,并加载导出的模型。使用输入数据对模型进行推理,并获取输出结果。可以通过 TensorFlow 提供的预测函数直接进行推理,或使用 TensorFlow Serving 等工具进行更高级的模型部署。 5. 部署到服务器或云平台:如果要在服务器或云平台上部署 TensorFlow 模型,可以使用诸如 Docker 和 Kubernetes 的容器化技术。这样可以将模型包装为容器,并提供可扩展的部署解决方案。 6. 性能优化:在部署期间,可以进行一些性能优化以提高模型的推理速度和效率。例如,使用 TensorFlow Lite 将模型转换为适用于移动设备或嵌入式设备的优化版本,或使用 TensorFlow GPU 支持利用 GPU 加速模型推理。 总之,部署 TensorFlow 模型需要准备模型、导出模型、安装所需库、加载和推理模型,并根据实际需求选择合适的部署方式。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值