使用protobuf解析tensorflow pb模型

1、安装protobuf,使用protoc把tensorflow模型中的proto文件转换为对应的.h和.cc,如下图所示:

解析用到的proto位置大概在tensorflow\core\framework和tensorflow\core\protobuf,当然这里面的proto比较多,只要把graph.proto包含的proto文件用到就可以了。

proto里面有syntax = "proto3",这就意味着protobuf的版本需要在3.0以上。

2、搭建VS工程

主要信息获取可以通过以下代码获取:

// tensorflow_pb.cpp : 定义控制台应用程序的入口点。
//

#include "stdafx.h"
#include <iostream>
#include <fstream>

#include "graph.pb.h"

int main()
{
	//char* pbname = "test.txt";
	char* pbname = "test.pb";
	std::fstream fp;
	fp.open(pbname,std::ios::in | std::ios::binary);
	if (!fp)
	{
		printf("file not found\n");
	}
	tensorflow::GraphDef graph_def;
	bool suc = graph_def.ParseFromIstream(&fp);	
	fp.close();
	int size = graph_def.node_size();
	//yes
	tensorflow::NodeDef node_info;
	node_info = graph_def.node(8);   //对应于python:oplist = get_operation();node_info = oplist[8];

	size = node_info.input_size();	

	//std::string input0_name = node_info.input(0);
	//std::string input1_name = node_info.input(1);

	google::protobuf::Map< std::string, tensorflow::AttrValue> attr_map;
	attr_map = node_info.attr();	//获取该结点的attr的map
	/*if (attr_map.contains("transpose_a"))
	{
		printf("yes");
	}
	tensorflow::AttrValue attr_temp = attr_map["transpose_a"];*/

	google::protobuf::Map< std::string, tensorflow::AttrValue>::iterator iter = attr_map.begin();
	while (iter != attr_map.end())
	{
		std::cout << iter->first << std::endl; //查看attrmap的key值
		iter++;
	}
	if (attr_map.contains("value"))
	{
		printf("yes");
	}
	tensorflow::AttrValue tensor_content = attr_map["value"];
	suc = tensor_content.has_tensor();
	tensorflow::TensorProto tensor_info = tensor_content.tensor();
	std::string tmp = tensor_info.tensor_content();
	float* data = (float*)tmp.data();	//如果该attr包含的tensor有值,需要将该值转换为float来用
	tensorflow::TensorShapeProto tens_shape = tensor_info.tensor_shape();
	size = tens_shape.dim_size();
	tensorflow::TensorShapeProto_Dim tens_dim = tens_shape.dim(0);
	int64_t size_zz = tens_dim.size(); //获取该Tensor的维度信息,这里获取shape的第一维信息
    return 0;
}

当然,在VS工程里需要将使用的protobuf里的src\google\protobuf这部分内容拉到工程里面,为了跑通程序,还需要将上述protobuf的名字包含test的文件去除,这些文件需要另外包含gmock等依赖;

3、最初打算分析tensorflow源码,但是后来感觉之前python的解析脚本用到的指令应该是tensorflow对上述代码的封装,就拿get_operation(tensor)_by_name为例,tensorflow内部会维护一个op结点(对应上述代码的node)及其name的字典;为了不依赖tensorflow,只好从protobuf解析来做了~

大致梳理一下:首先需要区分const节点和非const节点,然后将const节点中能计算的信息全部计算出来,便于后续直接获取;

然后根据数据流将所有节点进行排序;最后根据需求将可合并的节点进行合并,并解析;

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值