系统环境,win7 64位系统,VS2017,tensorflow 1.8.0.
实现C++调用tensorflow训练好的模型文件pb文件,进行预测。采用指针赋值给Tensor的方式。
1. Tensorflow库
由于个人使用CMake编译Tensorflow源码,一直因为种种原因没通过,就懒得再编译了。所以在GitHub上下载了现成的库使用。网址如下:https://github.com/fo40225/tensorflow-windows-wheel/tree/master/1.8.0/cpp。
我使用的是CPU版本的avx,可以先使用CPU_Z检测电脑是否支持avx。然后我使用这个库是Release版本的,刚开始在Debug下使用总是报错。
库文件如下:
把bin文件的路径加到环境变量里。
其他两个include和lib文件路径,在VS2017的项目属性里加到包含目录和库目录中等配置。
2. C++调用tensorflow的模型文件pb文件进行预测
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/core/errors.h"
using namespace std;
using namespace tensorflow;
using namespace tensorflow::ops;
using tensorflow::Tensor;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32;
//使用指针给Tensor赋值,作为模型的输入
void Pointer_to_Tensor(string raw_path, Tensor* output_tensor, int row, int col,int iww, int iwl)
{
short* data = new short[row*col]{ 0 };
//将需要的内容放到指针中
//此处省略data指针的赋值,就是自己模型需要输入的数据。
float *p = output_tensor->flat< float >().data(); //创建一个指向tensor的内容的指针
memcpy(p, img_data, row*col * sizeof(float)); //使用指针地址的copy操作,将输入数据指针拷贝给Tensor指向的指针。
}
//将模型的数据存储到指针中
void Tensor_to_Pointer(Tensor input_tensor, float* & dstData, int row, int col)
{
//创建一个指向tensor的内容的指针
float *p = input_tensor.flat< float >().data();
memcpy(dstData, p, row*col * sizeof(float));
}
int main()
{
string model_path = ".\\model\\a1.pb";
string data_path = ".\\data\\7.jpg";
//建立会话
Session *session;
Status status = NewSession(SessionOptions(), &session);
if (!status.ok())
{
cout << status.ToString() << "\n";
return 1;
}
cout << "Session Successfully created!\n";
//加载模型
GraphDef graphdef; //Graph Definition for current model
Status status_load = ReadBinaryProto(Env::Default(), model_path, &graphdef); //从pb文件中读取图模型;
if (!status_load.ok())
{
std::cout << "ERROR: Loading model failed..." << model_path << std::endl;
std::cout << status_load.ToString() << "\n";
return -1;
}
Status status_create = session->Create(graphdef); //将模型导入会话Session中;
if (!status_create.ok())
{
std::cout << "ERROR: Creating graph in session failed..." << status_create.ToString() << std::endl;
return -1;
}
cout << "Session successfully created." << endl;
//准备数据
int row = 256;
int col = 256;
Tensor resized_tensor(DT_FLOAT, TensorShape({ 1,row,col,1 })); //创建一个tensor作为输入网络的接口
Pointer_to_Tensor(data_path, &resized_tensor, row, col, iww, iwl);
cout << resized_tensor.DebugString() << endl;
//运行网络进行预测
string input_tensor_name = "image";
string output_tensor_name = "model/Sigmoid";
vector<tensorflow::Tensor> outputs;
string output_node = output_tensor_name;
Status status_run = session->Run({ { input_tensor_name, resized_tensor } }, { output_node }, {}, &outputs);
if (!status_run.ok())
{
cout << "ERROR: RUN failed..." << std::endl;
cout << status_run.ToString() << "\n";
return -1;
}
//把输出值给提取出来
cout << "Output tensor size:" << outputs.size() << std::endl;
/*for (std::size_t i = 0; i < outputs.size(); i++)
{
cout << outputs[i].DebugString() << endl;
}*/
Tensor t = outputs[0]; // Fetch the first tensor
cout << "Tensor shape: " << t.shape() << endl; //[1,128,128,1]
float * dstData = new float[row*col]{ 0 };
Tensor_to_Pointer(t, dstData, row, col);
return 0;
}
以上是C++调用tensorflow权值进行预测并且使用指针赋值给Tensor的全部过程。