上一篇博客讲到了scope,在TF的C++API中,所有的operation的第一个参数都是scope。其实C++的tf api有很多不一样的特性。我们往后慢慢学习,慢慢总结。
OK,这次就写篇实际的,也是这两天很折腾人的环节。我们得到了Pb文件,想在C++端使用,以便后续封装生成dll。那么如何导入Pb模型呢?TF还是很友好的,提供了ReadBinaryProto导入pb模型。
这里我写了一个函数,大家可以直接拿去用:
Status readPb(Session *sess, GraphDef &gdef, const string &modelPath)
{
Status status = ReadBinaryProto(Env::Default(), modelPath, &gdef);
if (!status.ok()) {
std::cerr << status.ToString() << std::endl;
}
else
{
cout << "load graph protobuf successfully" << std::endl;
}
status = sess->Create(gdef);
return status;
}
需要包含以下的头文件:
#include "tensorflow\core\public\session.h"
#include "tensorflow\core\platform\env.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"
using namespace tensorflow;
using std::cin;
using std::cout;
using std::vector;
自定义的readPb函数需要一个session类型的指针,一个GraphDef的引用类型,还有string类型,内容是pb文件的地址。
下面我们看看主函数怎么写:
Session * session = 0;
Status status = NewSession(SessionOptions(), &session);
Scope root = Scope::NewRootScope();
GraphDef gdef;
status = readPb(session, gdef, "flowNet.pb");
if (!status.ok()) {
std::cerr << status.ToString() << std::endl;
}
这样就可以调用pb文件了。Status类型可以帮助我们检测操作是否成功,如果成功,Status的成员函数ok()为True。如果操作不成功,我们可以使用ToString函数成员来输出未能导入图的问题。
常常是sess->Create这里会有问题。比如我就遇到过提示LeakyRelu这个节点找不到,原来是因为python中支持的tf.nn.leaky_relu在C++的API中不存在。遇到这种情况,只能去把python的leaky_relu换成使用普通函数写成的等价形式。
下一章我们会将如何导入预测图像。