使用Tensorflow C++ API编写预测代码
预测代码主要包括如下步骤:
1,创建Session
2,导入之间生成的模型
3,将模型设置到创建的Session里
4,设置模型输入输出
5,关闭Session
第一步,创建Session
//创建Session
Session* session;
Status status = NewSession(SessionOptions(),&session);
if(!status.ok()){
std::cout<<status.ToString()<<std::endl;
}
else{
std::cout<<"Session created successful"<<std::endl;
}
第二步,导入模型
//导入之前生成的模型
string model_path = "";
//从pb文件中读取图模型
GraphDef graphdef;
Status status_load = ReadBinaryProto(Env::Default(),model_path,&graphdef);
if(!status_load.ok()){
std::cout<<"ERROR: Loading model failed..."<<model_path<<std::endl;
std::c