python怎么使用训练好的模型设计_python tensorflow训练好的模型怎么在c++用

展开全部

// 导入之2113前已经保存好的模型

// 本程序来自5261tensorflow/c/c_api_test.cc

// 如果不明白4102,就看这个测试脚本就行1653了

const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123";

const string saved_model_dir = tensorflow::io::JoinPath(

tensorflow::testing::TensorFlowSrcRoot(), kSavedModel);

TF_SessionOptions* opt = TF_NewSessionOptions();

TF_Buffer* run_options = TF_NewBufferFromString("", 0);

TF_Buffer* metagraph = TF_NewBuffer();

TF_Status* s = TF_NewStatus();

const char* tags[] = {tensorflow::kSavedModelTagServe};

TF_Graph* graph = TF_NewGraph();

TF_Session* session = TF_LoadSessionFromSavedModel(

opt, run_options, saved_model_dir.c_str(), tags, 1, graph, metagraph, s);

TF_DeleteBuffer(run_options);

TF_DeleteSessionOptions(opt);

tensorflow::MetaGraphDef metagraph_def;

metagraph_def.ParseFromArray(metagraph->data, metagraph->length);

TF_DeleteBuffer(metagraph);

EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

CSession csession(session);

// Retrieve the regression signature from meta graph def.

const auto signature_def_map = metagraph_def.signature_def();

const auto signature_def = signature_def_map.at("regress_x_to_y");

const string input_name =

signature_def.inputs().at(tensorflow::kRegressInputs).name();

const string output_name =

signature_def.outputs().at(tensorflow::kRegressOutputs).name();

// Write {0, 1, 2, 3} as tensorflow::Example inputs.

Tensor input(tensorflow::DT_STRING, TensorShape({4}));

for (tensorflow::int64 i = 0; i < input.NumElements(); ++i) {

tensorflow::Example example;

auto* feature_map = example.mutable_features()->mutable_feature();

(*feature_map)["x"].mutable_float_list()->add_value(i);

input.flat()(i) = example.SerializeAsString();

}

const tensorflow::string input_op_name =

tensorflow::ParseTensorName(input_name).first.ToString();

TF_Operation* input_op =

TF_GraphOperationByName(graph, input_op_name.c_str());

ASSERT_TRUE(input_op != nullptr);

csession.SetInputs({{input_op, TF_Tensor_EncodeStrings(input)}});

const tensorflow::string output_op_name =

tensorflow::ParseTensorName(output_name).first.ToString();

TF_Operation* output_op =

TF_GraphOperationByName(graph, output_op_name.c_str());

ASSERT_TRUE(output_op != nullptr);

csession.SetOutputs({output_op});

csession.Run(s);

ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

TF_Tensor* out = csession.output_tensor(0);

ASSERT_TRUE(out != nullptr);

EXPECT_EQ(TF_FLOAT, TF_TensorType(out));

EXPECT_EQ(2, TF_NumDims(out));

EXPECT_EQ(4, TF_Dim(out, 0));

EXPECT_EQ(1, TF_Dim(out, 1));

float* values = static_cast(TF_TensorData(out));

// These values are defined to be (input / 2) + 2.

EXPECT_EQ(2, values[0]);

EXPECT_EQ(2.5, values[1]);

EXPECT_EQ(3, values[2]);

EXPECT_EQ(3.5, values[3]);

csession.CloseAndDelete(s);

EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

TF_DeleteGraph(graph);

TF_DeleteStatus(s);

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值