整理 C++ 调用 TensorFlow 的代码示例。
1. 构造输入张量、运行 Session，并读取（观测）输出张量的值。
/*
* inference4beginer.cpp
* Copyright (C) 2017 fisherman
*/
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/platform/env.h>
#include <tensorflow/core/framework/tensor.h>
#include <tensorflow/core/framework/graph.pb.h>
#include <glog/logging.h>
#include <string>
#include <vector>
#include <memory>
int main() {
tensorflow::GraphDef graph;
std::string graph_file = "beginner.pb";
//1 读取Graph, 如果是文本形式的pb,使用ReadTextProto
tensorflow::Status status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(), graph_file, &graph);
if (!status.ok()) {
LOG(FATAL) << status.ToString() << " with graph_file:" << graph_file;
return -1;
}
//2 创建Session
std::unique_ptr<tensorflow::Session> sess;
sess.reset(tensorflow::NewSession({}));
status = sess->Create(graph);
if (!status.ok()) {
LOG(FATAL) << status.ToString();
return -1;
}
//3 构造Tensor并赋值
//3.1 创建tensorflow::Tensor x
tensorflow::Tensor x(tensorflow::DT_FLOAT, tensorflow::TensorShape({}));
//3.2 获取x的Eigen::TensorMap
auto x_map = x.tensor<float, 0>(); // == x.scalar<float>()
//auto ->Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
//3.3 获取指针, 可以任意复制了。
float* data = x_map.data();
*data = 1.0f;
std::vector<std::pair<std::string, tensorflow::Tensor>> inputs;
inputs.emplace_back(std::make_pair("x", x));
std::vector<std::string> output_tensor_names;
output_tensor_names.emplace_back("y");
std::vector<tensorflow::Tensor> y;
// 4 Session Run
status = sess->Run(inputs, output_tensor_names, {}, &y);
if (!status.ok()) {
LOG(ERROR) << status.ToString();
return -1;
}
// 5 打印结果
float y_value = *(y[0].tensor<float, 0>().data());
LOG(INFO) << "y = a*x+b\n" << " a(0.1) b(0) in graph, x(1), y(0.1):" << y_value;
return 0;
}
2. 加载 TensorFlow 模型进行预测（加载图、加载模型、执行预测）：
链接: https://blog.csdn.net/heiheiya/article/details/89849737
链接:https://blog.csdn.net/yz2zcx/article/details/100406635
3. MNIST数据的C++代码案例
4. MLP模型的C++代码案例
5. C++ 调用 TensorFlow（张量形状 shape 相关）
链接:https://blog.csdn.net/koibiki/article/details/88238170