1、安装protobuf,使用protoc把tensorflow模型中的proto文件转换为对应的.h和.cc,如下图所示:
解析用到的proto位置大概在tensorflow\core\framework和tensorflow\core\protobuf,当然这里面的proto比较多,只要把graph.proto包含的proto文件用到就可以了。
proto里面有syntax = "proto3",这就意味着protobuf的版本需要在3.0以上。
2、搭建VS工程
主要信息获取可以通过以下代码获取:
// tensorflow_pb.cpp : 定义控制台应用程序的入口点。
//
#include "stdafx.h"
#include <iostream>
#include <fstream>
#include "graph.pb.h"
int main()
{
//char* pbname = "test.txt";
char* pbname = "test.pb";
std::fstream fp;
fp.open(pbname,std::ios::in | std::ios::binary);
if (!fp)
{
printf("file not found\n");
}
tensorflow::GraphDef graph_def;
bool suc = graph_def.ParseFromIstream(&fp);
fp.close();
int size = graph_def.node_size();
//yes
tensorflow::NodeDef node_info;
node_info = graph_def.node(8); //对应于python:oplist = get_operation();node_info = oplist[8];
size = node_info.input_size();
//std::string input0_name = node_info.input(0);
//std::string input1_name = node_info.input(1);
google::protobuf::Map< std::string, tensorflow::AttrValue> attr_map;
attr_map = node_info.attr(); //获取该结点的attr的map
/*if (attr_map.contains("transpose_a"))
{
printf("yes");
}
tensorflow::AttrValue attr_temp = attr_map["transpose_a"];*/
google::protobuf::Map< std::string, tensorflow::AttrValue>::iterator iter = attr_map.begin();
while (iter != attr_map.end())
{
std::cout << iter->first << std::endl; //查看attrmap的key值
iter++;
}
if (attr_map.contains("value"))
{
printf("yes");
}
tensorflow::AttrValue tensor_content = attr_map["value"];
suc = tensor_content.has_tensor();
tensorflow::TensorProto tensor_info = tensor_content.tensor();
std::string tmp = tensor_info.tensor_content();
float* data = (float*)tmp.data(); //如果该attr包含的tensor有值,需要将该值转换为float来用
tensorflow::TensorShapeProto tens_shape = tensor_info.tensor_shape();
size = tens_shape.dim_size();
tensorflow::TensorShapeProto_Dim tens_dim = tens_shape.dim(0);
int64_t size_zz = tens_dim.size(); //获取该Tensor的维度信息,这里获取shape的第一维信息
return 0;
}
当然,在VS工程里需要将使用的protobuf里的src\google\protobuf这部分内容拉到工程里面,为了跑通程序,还需要将上述protobuf的名字包含test的文件去除,这些文件需要另外包含gmock等依赖;
3、最初打算分析tensorflow源码,但是后来感觉之前python的解析脚本用到的指令应该是tensorflow对上述代码的封装,就拿get_operation(tensor)_by_name为例,tensorflow内部会维护一个op结点(对应上述代码的node)及其name的字典;为了不依赖tensorflow,只好从protobuf解析来做了~
大致梳理一下:首先需要区分const节点和非const节点,然后将const节点中能计算的信息全部计算出来,便于后续直接获取;
然后根据数据流将所有节点进行排序;最后根据需求将可合并的节点进行合并,并解析;