【TensorFlow源码系列】【零】使用TensorFlow C++ 接口进行模型推理

#include <string>
#include <vector>
#include <iostream>
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/framework/tensor.h"

//using namespace std;
//using namespace tensorflow;

int main(int argc,char **argv)
{
	// 1. 创建session
	Session * session;
	Status status = NewSession(SessionOptions(),&session);
    
	// 2. 模型路径
	string model_path = "mnist.pb";
	
	// 3. 将pb原始模型导入到GraphDef中
	GraphDef graphdef;
    status = ReadBinaryProto(Env::Default(),model_path,&graphdef);
	
	if(!status.ok()){
		
		return 0;
	}
	
	// 4. 将原始模型加载到session中
	status = session->Create(graphdef);
	
	if(!status.ok()){
		
		return 0;
	}
	
	// 5. 创建输入输出tensor
	std::vector<std::pair<std::string,tensorflow::Tensor>> inputs;
	std::vector<tensorflow::Tensor> outputs;
	
	tensorflow::Tensor input_tensor(DT_FLOAT,tensorflow::TesorShape({1,28,28,1}));
	
	// 6. 获取输入tensor指针,向里面填写数据
	auto plane_tensor = input_tensor.tensor<float,4>();
	
	for(int n = 0; n < 1; ++n)
		for(int h = 0 ; h < 28; ++h)
			for(int w = 0; w < 28; ++w)
				for(int c = 0; c < 1; ++c){
					plane_tensor(n,h,w,c) = 1.0f;
				}
	inputs.push_back({"inputs",input_tensor});
	
	// 7. 运行模型,需要传递输入tensor,输出tensor,输出tensor的name ---softmax
	status = session->Run(inputs,{"softmax"},{},&outputs);
	if(!status.ok()){
		
		return 0;
	}
	
	// 8. 计算完成后,将计算结果从outputtensor中取出来
	auto out_tensor = out_tensor[0].tensor<float,2>();
	for(int n = 0; n < 1; ++n)
		for(int h = 0 ; h < 10; ++h){
					std::cout<<out_tensor(n,h)<<std::endl;
				}
	
	return 0;
}

后续源码分析,会基于这个主体流程作分析。

转载于:https://my.oschina.net/u/3800567/blog/2243794

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值