前述博文 Tensorflow C++ 从训练到部署(2):简单图的保存、读取与 CMake 编译 和 Tensorflow C++ 从训练到部署(3):使用 Keras 训练和部署 CNN 使用 Tensorflow/Keras 的 Python API 进行训练,并使用 C++ API 进行了预测。由于 C++ API 需要编译 Tensorflow 源码,还是比较麻烦的。而 Tensorflow 官方提供了 C API 编译好的库文件,相对部署上比较容易(直接复制库文件到自己的工程即可),本文将介绍使用 C API 进行预测的方法。对于 Python 训练部分,与前述文章相同不做赘述。
0、系统环境
Ubuntu 16.04
Tensorflow 1.12.0
1、安装依赖
1、GPU 支持安装(可选)
CUDA 9.0
cuDNN 7.x
其中 1.12.0 的下载地址如下(我这里提供了包含TX2 aarch64在内的几个版本):
将库解压到 third_party/libtensorflow 目录。
如果上面的版本都不符合你的需求,你可以参照这篇文章编译你需要的版本。
2、TFUtils 工具类
为了简便起见,我们首先将常用的 C API 封装为一个 TFUtils 工具类:
1)文件 utils/TFUtils.hpp:
C++
// Licensed under the MIT License .
// Copyright (c) 2018 Liu Xiao and Daniil Goncharov .
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#pragma once
#if defined(_MSC_VER)
# if !defined(COMPILER_MSVC)
# define COMPILER_MSVC // Set MSVC visibility of exported symbols in the shared library.
# endif
# pragma warning(push)
# pragma warning(disable : 4190)
#endif
#include
#include
#include
#include
#include
#include
class TFUtils {
public:
enum STATUS
{
SUCCESS = 0,
SESSION_CREATE_FAILED = 1,
MODEL_LOAD_FAILED = 2,
FAILED_RUN_SESSION = 3,
MODEL_NOT_LOADED = 4,
};
TFUtils();
STATUS LoadModel(std::string model_file);
~TFUtils();
TF_Output GetOperationByName(std::string name, int idx);
STATUS RunSession(const std::vector& inputs, const std::vector& input_tensors,
const std::vector& outputs, std::vector& output_tensors);
// Static functions
template
static TF_Tensor* CreateTensor(TF_DataType data_type,
const std::vector<:int64_t>& dims,
const std::vector& data) {
return CreateTensor(data_type,
dims.data(), dims.size(),
data.data(), data.size() * sizeof(T));
}
static void DeleteTensor(TF_Tensor* tensor);
static void DeleteTensors(const std::vector& tensors);
template
static std::vector<:vector>> GetTensorsData(const std::vector& tensors) {
std::vector<:vector>> data;
data.reserve(tensors.size());
for (const auto t : tensors) {
data.push_back(GetTensorData(t));
}
return data;
}
static TF_Tensor* CreateTensor(TF_DataType data_type,
const std::int64_t* dims, std::size_t num_dims,
const void* data, std::size_t len);
template
static std::vector GetTensorData(const TF_Tensor* tensor) {
const auto data = static_cast(TF_TensorData(tensor));
if (data == nullptr) {
return {};
}
return {data, data + (TF_TensorByteSize(tensor) / TF_DataTypeSize(TF_TensorType(tensor)))};
}
// STATUS GetErrorCode();
static void PrinStatus(STATUS status);
private:
TF_Graph* graph_def;
TF_Session* sess;
STATUS init_error_code;
private:
TF_Graph* LoadGraphDef(const char* file);
TF_Session* CreateSession(TF_Graph* graph);
bool CloseAndDeleteSession(TF_Session* sess);
bool RunSession(TF_Session* sess,
const TF_Output* inputs, TF_Tensor* const* input_tensors, std::size_t ninputs,
const TF_Output* outputs, TF_Tensor** output_tensors, std::size_t noutputs);
bool RunSession(TF_Session* sess,
const std::vector& inputs, const std::vector& input_tensors,
const std::vector& outputs, std::vector& output_tensors);
}; // End class TFUtils
#if defined(_MSC_VER)
# pragma warning(pop)
#endif
2)文件 utils/TFUtils.cpp:
C++
// Licensed under the MIT License .
// Copyright (c) 2018 Liu Xiao and Daniil Goncharov .
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#if defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable : 4996)
#endif
#include "TFUtils.hpp"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>
// Public functions
// Constructs an empty wrapper: no graph, no session, and status MODEL_NOT_LOADED until
// LoadModel() is called. The handles are explicitly nulled so that the destructor and
// LoadModel() never observe indeterminate pointers.
TFUtils::TFUtils()
    : graph_def(nullptr), sess(nullptr), init_error_code(MODEL_NOT_LOADED)
{
}
// Public functions
// Loads the frozen GraphDef at `model_file` and creates a session for it.
// Any model loaded by a previous call is released first, so calling this repeatedly does not
// leak graphs or sessions.
// Returns SUCCESS, MODEL_LOAD_FAILED (file unreadable / import failed) or
// SESSION_CREATE_FAILED; the same value is stored in init_error_code.
TFUtils::STATUS TFUtils::LoadModel(std::string model_file)
{
    // init_error_code tells which handles a previous call assigned: both after SUCCESS,
    // only graph_def after SESSION_CREATE_FAILED, neither otherwise.
    if (init_error_code == SUCCESS) {
        CloseAndDeleteSession(sess);
    }
    if (init_error_code == SUCCESS || init_error_code == SESSION_CREATE_FAILED) {
        TF_DeleteGraph(graph_def);
    }
    graph_def = nullptr;
    sess = nullptr;
    init_error_code = MODEL_NOT_LOADED;

    // Load graph
    graph_def = LoadGraphDef(model_file.c_str());
    if (graph_def == nullptr) {
        std::cerr << "loading model failed ......" << std::endl;
        init_error_code = MODEL_LOAD_FAILED;
        return MODEL_LOAD_FAILED;
    }

    // Create session
    sess = CreateSession(graph_def);
    if (sess == nullptr) {
        init_error_code = SESSION_CREATE_FAILED;
        std::cerr << "create sess failed ......" << std::endl;
        return SESSION_CREATE_FAILED;
    }

    init_error_code = SUCCESS;
    return init_error_code;
}
// Releases the session and graph created by LoadModel().
// init_error_code records which handles LoadModel() actually assigned: both after SUCCESS,
// only graph_def after SESSION_CREATE_FAILED, neither otherwise. Gating on it (not only on
// the pointers) avoids reading members that were never assigned when LoadModel() was not
// called or failed before setting them.
TFUtils::~TFUtils()
{
    const bool has_session = (init_error_code == SUCCESS);
    const bool has_graph = (init_error_code == SUCCESS || init_error_code == SESSION_CREATE_FAILED);
    if (has_session && sess != nullptr)
        CloseAndDeleteSession(sess);
    if (has_graph && graph_def != nullptr)
        TF_DeleteGraph(graph_def);
}
// Looks up the operation called `name` in the loaded graph and pairs it with output index
// `idx`. Returns {nullptr, idx} when no graph has been imported (LoadModel() not called or
// failed at graph load) or when the graph has no operation with that name.
TF_Output TFUtils::GetOperationByName(std::string name, int idx) {
    // graph_def is only valid once graph import succeeded (SUCCESS or SESSION_CREATE_FAILED).
    if (init_error_code != SUCCESS && init_error_code != SESSION_CREATE_FAILED) {
        return {nullptr, idx};
    }
    return {TF_GraphOperationByName(graph_def, name.c_str()), idx};
}
// Runs the loaded graph: feeds input_tensors[i] into inputs[i] and writes the fetched values
// of outputs[i] into output_tensors[i] (ownership passes to the caller).
// Returns init_error_code if no model is loaded, FAILED_RUN_SESSION if any endpoint is
// unresolved (oper == nullptr, e.g. a name GetOperationByName() did not find) or the run
// fails, and SUCCESS otherwise.
TFUtils::STATUS TFUtils::RunSession(const std::vector<TF_Output>& inputs, const std::vector<TF_Tensor*>& input_tensors,
                                    const std::vector<TF_Output>& outputs, std::vector<TF_Tensor*>& output_tensors)
{
    if (init_error_code != SUCCESS)
        return init_error_code;

    // Reject unresolved endpoints here instead of handing a null operation to TF_SessionRun.
    for (const TF_Output& in : inputs) {
        if (in.oper == nullptr)
            return FAILED_RUN_SESSION;
    }
    for (const TF_Output& out : outputs) {
        if (out.oper == nullptr)
            return FAILED_RUN_SESSION;
    }

    const bool run_ret = RunSession(sess, inputs, input_tensors, outputs, output_tensors);
    return run_ret ? SUCCESS : FAILED_RUN_SESSION;
}
// Writes one line "status = <NAME>" to stdout (flushed) for the given status code;
// values outside the enum print "status = NOT FOUND".
void TFUtils::PrinStatus(STATUS status)
{
    const char* line = nullptr;
    switch (status) {
    case SUCCESS:               line = "status = SUCCESS"; break;
    case SESSION_CREATE_FAILED: line = "status = SESSION_CREATE_FAILED"; break;
    case MODEL_LOAD_FAILED:     line = "status = MODEL_LOAD_FAILED"; break;
    case FAILED_RUN_SESSION:    line = "status = FAILED_RUN_SESSION"; break;
    case MODEL_NOT_LOADED:      line = "status = MODEL_NOT_LOADED"; break;
    default:                    line = "status = NOT FOUND"; break;
    }
    std::cout << line << std::endl;
}
// Static functions
// Deallocator installed on TF_Buffer objects built by ReadBufferFromFile: their payload
// comes from std::malloc, so it is returned with std::free. The length is not needed.
static void DeallocateBuffer(void* block, size_t /*length*/) {
    std::free(block);
}
// Reads the whole file at `file` into a newly allocated TF_Buffer (payload from std::malloc,
// released by DeallocateBuffer when the buffer is deleted).
// Returns nullptr if the file cannot be opened, is empty, its size cannot be determined,
// memory cannot be allocated, or fewer bytes than expected were read.
static TF_Buffer* ReadBufferFromFile(const char* file) {
    std::FILE* const f = std::fopen(file, "rb");
    if (f == nullptr) {
        return nullptr;
    }

    // ftell() reports -1 on error, which the `< 1` check below also rejects.
    long fsize = -1;
    if (std::fseek(f, 0, SEEK_END) == 0) {
        fsize = std::ftell(f);
    }
    if (fsize < 1 || std::fseek(f, 0, SEEK_SET) != 0) {
        std::fclose(f);
        return nullptr;
    }

    const std::size_t size = static_cast<std::size_t>(fsize);
    void* const data = std::malloc(size);
    if (data == nullptr) {
        std::fclose(f);
        return nullptr;
    }

    // A short read would otherwise hand uninitialized bytes to the GraphDef parser.
    const std::size_t nread = std::fread(data, 1, size, f);
    std::fclose(f);
    if (nread != size) {
        std::free(data);
        return nullptr;
    }

    TF_Buffer* buf = TF_NewBuffer();
    buf->data = data;
    buf->length = size;
    buf->data_deallocator = DeallocateBuffer;
    return buf;
}
// Private functions
// Reads a serialized GraphDef from `file` and imports it into a fresh TF_Graph.
// Returns nullptr when `file` is null, the file cannot be read, or the import fails;
// otherwise the caller owns the returned graph (release with TF_DeleteGraph).
TF_Graph* TFUtils::LoadGraphDef(const char* file) {
    TF_Buffer* const graph_bytes = (file != nullptr) ? ReadBufferFromFile(file) : nullptr;
    if (!graph_bytes) {
        return nullptr;
    }

    TF_Graph* imported = TF_NewGraph();
    TF_Status* import_status = TF_NewStatus();
    TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
    TF_GraphImportGraphDef(imported, graph_bytes, import_opts, import_status);
    TF_DeleteImportGraphDefOptions(import_opts);
    TF_DeleteBuffer(graph_bytes);

    const bool import_ok = (TF_GetCode(import_status) == TF_OK);
    TF_DeleteStatus(import_status);
    if (!import_ok) {
        TF_DeleteGraph(imported);
        return nullptr;
    }
    return imported;
}
// Creates a session with default options for `graph`.
// Returns nullptr if `graph` is null or TF_NewSession fails; otherwise the caller owns the
// session (release with CloseAndDeleteSession).
TF_Session* TFUtils::CreateSession(TF_Graph* graph) {
    if (graph == nullptr) {
        return nullptr;
    }
    TF_Status* status = TF_NewStatus();
    TF_SessionOptions* options = TF_NewSessionOptions();
    TF_Session* sess = TF_NewSession(graph, options, status);
    TF_DeleteSessionOptions(options);

    const bool ok = (TF_GetCode(status) == TF_OK);
    // The status must be freed on the success path too (it previously leaked there).
    TF_DeleteStatus(status);
    return ok ? sess : nullptr;
}
// Closes and then deletes `sess`. Deletion is attempted even if closing failed, so the
// session object is never leaked. Returns true only if both steps report TF_OK;
// a null session is treated as nothing to release and returns true.
bool TFUtils::CloseAndDeleteSession(TF_Session* sess) {
    if (sess == nullptr) {
        return true;
    }
    TF_Status* status = TF_NewStatus();

    TF_CloseSession(sess, status);
    const bool closed = (TF_GetCode(status) == TF_OK);

    TF_DeleteSession(sess, status);
    const bool deleted = (TF_GetCode(status) == TF_OK);

    TF_DeleteStatus(status);
    return closed && deleted;
}
// Low-level wrapper over TF_SessionRun: feeds `ninputs` tensors and fetches `noutputs`
// tensors (written into output_tensors; caller takes ownership). No run options,
// no target operations, no run metadata.
// An array pointer may be null only when its count is zero (std::vector::data() of an empty
// vector can be null). Returns false on invalid arguments or if the run fails; the
// TensorFlow error message is written to std::cerr in the latter case.
bool TFUtils::RunSession(TF_Session* sess,
                         const TF_Output* inputs, TF_Tensor* const* input_tensors, std::size_t ninputs,
                         const TF_Output* outputs, TF_Tensor** output_tensors, std::size_t noutputs) {
    const bool inputs_ok = (ninputs == 0) || (inputs != nullptr && input_tensors != nullptr);
    const bool outputs_ok = (noutputs == 0) || (outputs != nullptr && output_tensors != nullptr);
    if (sess == nullptr || !inputs_ok || !outputs_ok) {
        return false;
    }

    TF_Status* status = TF_NewStatus();
    TF_SessionRun(sess,
                  nullptr,                                                   // Run options.
                  inputs, input_tensors, static_cast<int>(ninputs),          // Input tensors, input tensor values, number of inputs.
                  outputs, output_tensors, static_cast<int>(noutputs),       // Output tensors, output tensor values, number of outputs.
                  nullptr, 0,                                                // Target operations, number of targets.
                  nullptr,                                                   // Run metadata.
                  status                                                     // Output status.
    );

    const bool ok = (TF_GetCode(status) == TF_OK);
    if (!ok) {
        std::cerr << "TF_SessionRun failed: " << TF_Message(status) << std::endl;
    }
    TF_DeleteStatus(status);
    return ok;
}
bool TFUtils::RunSession(TF_Session* sess,
const std::vector& inputs, const std::vector& input_tensors,
const std::vector& outputs, std::vector& output_tensors) {
return RunSession(sess,
inputs.data(), input_tensors.data(), input_tensors.size(),
outputs.data(), output_tensors.data(), output_tensors.size());
}
// Allocates a tensor of `data_type` with shape dims[0..num_dims) and copies up to `len`
// bytes from `data` into it (never more than the tensor's byte size).
// `dims` may be null for a scalar (num_dims == 0) and `data` may be null when len == 0;
// both are what an empty std::vector's data() can return.
// Returns nullptr on invalid arguments or allocation failure; the caller owns the result
// (release with DeleteTensor).
TF_Tensor* TFUtils::CreateTensor(TF_DataType data_type,
                                 const std::int64_t* dims, std::size_t num_dims,
                                 const void* data, std::size_t len) {
    if ((dims == nullptr && num_dims != 0) || (data == nullptr && len != 0)) {
        return nullptr;
    }

    TF_Tensor* tensor = TF_AllocateTensor(data_type, dims, static_cast<int>(num_dims), len);
    if (tensor == nullptr) {
        return nullptr;
    }
    if (len == 0) {
        // Nothing to copy; a zero-byte tensor may legitimately have no data pointer.
        return tensor;
    }

    void* tensor_data = TF_TensorData(tensor);
    if (tensor_data == nullptr) {
        TF_DeleteTensor(tensor);
        return nullptr;
    }
    std::memcpy(tensor_data, data, std::min(len, TF_TensorByteSize(tensor)));
    return tensor;
}
// Frees a single tensor; a null pointer is silently ignored.
void TFUtils::DeleteTensor(TF_Tensor* tensor) {
    if (tensor != nullptr) {
        TF_DeleteTensor(tensor);
    }
}
void TFUtils::DeleteTensors(const std::vector& tensors) {
for (auto t : tensors) {
TF_DeleteTensor(t);
}
}
#if defined(_MSC_VER)
# pragma warning(pop)
#endif
3、简单图的读取与预测
在前述文章 Tensorflow C++ 从训练到部署(2):简单图的保存、读取与 CMake 编译 中我们已经介绍了一个 c=a*b 的简单“网络”是如何计算的。
其中 Python 构建网络和预测部分就不重复了,详见该文所述。这里直接给出 C API 的代码:
文件名:simple/load_simple_net_c_api.cc
C++
// Licensed under the MIT License .
// Copyright (c) 2018 Liu Xiao and Daniil Goncharov .
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#include "../utils/TFUtils.hpp"
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Loads the frozen c = a * b graph given on the command line, feeds a = 2 and b = 3
// (each shaped {1, 1}) and prints the value of "c".
// Exit codes: 0 on success, 1 for usage / model / setup errors, 2 when the session run
// fails or returns no data.
int main(int argc, char* argv[])
{
    if (argc != 2)
    {
        std::cerr << std::endl << "Usage: ./project path_to_graph.pb" << std::endl;
        return 1;
    }
    const std::string graph_path = argv[1];

    // TFUtils init
    TFUtils TFU;
    TFUtils::STATUS status = TFU.LoadModel(graph_path);
    if (status != TFUtils::SUCCESS) {
        std::cerr << "Can't load graph" << std::endl;
        return 1;
    }

    // Resolve graph endpoints; an unknown operation name yields oper == nullptr.
    const std::vector<TF_Output> input_ops = {TFU.GetOperationByName("a", 0),
                                              TFU.GetOperationByName("b", 0)};
    const std::vector<TF_Output> output_ops = {TFU.GetOperationByName("c", 0)};
    if (input_ops[0].oper == nullptr || input_ops[1].oper == nullptr || output_ops[0].oper == nullptr) {
        std::cerr << "Can't find operations a, b or c in the graph" << std::endl;
        return 1;
    }

    // Input Tensor Create (released below on every path)
    const std::vector<std::int64_t> input_a_dims = {1, 1};
    const std::vector<float> input_a_vals = {2.0f};
    const std::vector<std::int64_t> input_b_dims = {1, 1};
    const std::vector<float> input_b_vals = {3.0f};
    const std::vector<TF_Tensor*> input_tensors = {TFUtils::CreateTensor(TF_FLOAT, input_a_dims, input_a_vals),
                                                   TFUtils::CreateTensor(TF_FLOAT, input_b_dims, input_b_vals)};

    // Output Tensor Create: the slot is filled in by RunSession.
    std::vector<TF_Tensor*> output_tensors = {nullptr};

    int exit_code = 0;
    if (input_tensors[0] == nullptr || input_tensors[1] == nullptr) {
        std::cerr << "Can't create input tensors" << std::endl;
        exit_code = 1;
    } else {
        status = TFU.RunSession(input_ops, input_tensors,
                                output_ops, output_tensors);
        TFUtils::PrinStatus(status);
        if (status == TFUtils::SUCCESS) {
            const std::vector<std::vector<float>> data = TFUtils::GetTensorsData<float>(output_tensors);
            if (!data.empty() && !data[0].empty()) {
                std::cout << "Output value: " << data[0][0] << std::endl;
            } else {
                std::cerr << "Output tensor is empty" << std::endl;
                exit_code = 2;
            }
        } else {
            std::cerr << "Error run session" << std::endl;
            exit_code = 2;
        }
    }

    // Release tensors on every path; DeleteTensor ignores nullptr entries.
    for (TF_Tensor* t : input_tensors) {
        TFUtils::DeleteTensor(t);
    }
    for (TF_Tensor* t : output_tensors) {
        TFUtils::DeleteTensor(t);
    }
    return exit_code;
}
简单解释一下:
C++
TFUtils::STATUS status = TFU.LoadModel(graph_path);
这一行是加载 pb 文件。
C++
// Input Tensor Create
const std::vector<:int64_t> input_a_dims = {1, 1};
const std::vector input_a_vals = {2.0};
const std::vector<:int64_t> input_b_dims = {1, 1};
const std::vector input_b_vals = {3.0};
const std::vector input_ops = {TFU.GetOperationByName("a", 0),
TFU.GetOperationByName("b", 0)};
const std::vector input_tensors = {TFUtils::CreateTensor(TF_FLOAT, input_a_dims, input_a_vals),
TFUtils::CreateTensor(TF_FLOAT, input_b_dims, input_b_vals)};
// Output Tensor Create
const std::vector output_ops = {TFU.GetOperationByName("c", 0)};
std::vector output_tensors = {nullptr}
这一段是创建两个输入 tensor 以及输入的 ops。注意这里的 CreateTensor 在后面都需要调用 DeleteTensors 进行内存释放。输出的 tensors 还没创建先定义为 nullptr。
C++
status = TFU.RunSession(input_ops, input_tensors,
output_ops, output_tensors);
这一行是运行网络。
C++
const std::vector<:vector>> data = TFUtils::GetTensorsData(output_tensors);
const std::vector result = data[0];
这两行是从输出的 output_tensors 读取数据到一个二维 vector(const std::vector<std::vector<float>>)。我们这里输出只有 "c" 一个名字,而且只有一个索引 0,因此直接取出 data[0] 就是我们原本想要的输出。
编译运行这一文件,如果没有问题则会得到如下输出:
Shell
status = SUCCESS
Output value: 6
4、CNN的读取与预测
与刚才小节3相似,CNN网络也是一样的流程,还是以最基本的 fashion_mnist 为例,该网络的训练和保存流程请参考之前的文章。这里我们仅介绍 C API 进行预测的部分。由于我们这里需要读取一幅图并转化成 Tensor 输入网络,我们构造一个简单的函数 Mat2Tensor 实现这一转换:
1)Mat2Tensor 部分,文件:fashion_mnist/utils/mat2tensor_c_cpi.h
C++
// Licensed under the MIT License .
// Copyright (c) 2018 Liu Xiao .
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#ifndef TENSORFLOW_CPP_MAT2TENSOR_C_H
#define TENSORFLOW_CPP_MAT2TENSOR_C_H
#include <cstddef>
#include <cstdint>
#include <vector>
#include <tensorflow/c/c_api.h>
#include "opencv2/core/core.hpp"
// NOTE(review): the original include of TFUtils.hpp was lost in extraction. This relative path
// assumes utils/ sits next to fashion_mnist/ at the repository root; adjust it to your layout.
#include "../../utils/TFUtils.hpp"

// Converts an OpenCV image into a float TF_Tensor of shape
// {1, height, width, channels}: pixels are converted to 32-bit float and multiplied by `normal`
// (default 1/255).
// `img` is not modified. Returns nullptr for an empty image or if tensor creation fails;
// otherwise the caller owns the tensor (release with TFUtils::DeleteTensor).
// Declared inline because it is defined in a header that may be included by several
// translation units.
inline TF_Tensor* Mat2Tensor(const cv::Mat& img, float normal = 1/255.0) {
    if (img.empty()) {
        return nullptr;
    }
    const std::vector<std::int64_t> input_dims = {1, img.size().height, img.size().width, img.channels()};

    // Convert to float 32 and do normalize ops. convertTo allocates the destination itself,
    // so fake_mat is continuous even if img is a non-continuous ROI.
    cv::Mat fake_mat;
    img.convertTo(fake_mat, CV_32FC(img.channels()));
    fake_mat *= normal;

    return TFUtils::CreateTensor(TF_FLOAT,
                                 input_dims.data(), input_dims.size(),
                                 fake_mat.data, fake_mat.total() * fake_mat.elemSize());
}
#endif //TENSORFLOW_CPP_MAT2TENSOR_C_H
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49// Licensed under the MIT License .
// Copyright (c) 2018 Liu Xiao .
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#ifndef TENSORFLOW_CPP_MAT2TENSOR_C_H
#define TENSORFLOW_CPP_MAT2TENSOR_C_H
#include
#include
#include
#include
#include "opencv2/core/core.hpp"
// Converts an OpenCV image into a 4-D float32 TF_Tensor with dims
// {1, rows, cols, channels}, i.e. a batch of one image in NHWC order.
//
// img    : source image; its element depth is converted to 32-bit float and
//          its channel count is preserved.
// normal : scale factor applied to every element (the default 1/255 maps
//          8-bit pixel values 0..255 onto [0, 1]).
// Returns a newly created tensor owned by the caller (release it with
// TFUtils::DeleteTensors / the matching TFUtils delete helper), or nullptr
// when `img` is empty (for example after a failed cv::imread).
//
// Declared inline because this function is defined in a header: without it,
// including the header from two translation units violates the ODR.
inline TF_Tensor* Mat2Tensor(const cv::Mat &img, float normal = 1/255.0) {
    if (img.empty()) {
        return nullptr;
    }
    const std::vector<std::int64_t> input_dims = {1, img.rows, img.cols, img.channels()};
    // convertTo keeps the channel count, switches the depth to float32 and
    // applies the scale factor in a single pass. The destination is freshly
    // allocated, hence continuous, which matters because CreateTensor receives
    // only a single data pointer plus a byte length.
    cv::Mat float_mat;
    img.convertTo(float_mat, CV_32F, normal);
    return TFUtils::CreateTensor(TF_FLOAT,
                                 input_dims.data(), input_dims.size(),
                                 float_mat.data, float_mat.total() * float_mat.elemSize());
}
#endif //TENSORFLOW_CPP_MAT2TENSOR_C_H
2)网络读取与预测,这部分与刚才的小节3基本一样,就不做解释了:
C++
// Licensed under the MIT License .
// Copyright (c) 2018 Liu Xiao and Daniil Goncharov .
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#include "../utils/TFUtils.hpp"
#include "utils/mat2tensor_c_cpi.h"
#include
#include
// OpenCV
#include
#include
//std::string class_names[10] = {'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'};
// Human-readable labels indexed by the model's output class id.
// NOTE(review): these look like the Fashion-MNIST classes in their standard
// order — confirm they match the label order used during training.
const std::string class_names[] = {"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"};
// Returns the index of the largest score in `result` (defined after main).
int ArgMax(const std::vector<float> result);
int main(int argc, char* argv[])
{
if (argc != 3)
{
std::cerr << std::endl << "Usage: ./project path_to_graph.pb path_to_image.png" << std::endl;
return 1;
}
// Load graph
std::string graph_path = argv[1];
// TFUtils init
TFUtils TFU;
TFUtils::STATUS status = TFU.LoadModel(graph_path);
if (status != TFUtils::SUCCESS) {
std::cerr << "Can't load graph" << std::endl;
return 1;
}
// Load image and convert to tensor
std::string image_path = argv[2];
cv::Mat image = cv::imread(image_path, CV_LOAD_IMAGE_GRAYSCALE);
const std::vector<:int64_t> input_dims = {1, image.size().height, image.size().width, image.channels()};
TF_Tensor* input_image = Mat2Tensor(image, 1/255.0);
// Input Tensor/Ops Create
const std::vector input_tensors = {input_image};
const std::vector input_ops = {TFU.GetOperationByName("input_image_input", 0)};
// Output Tensor/Ops Create
const std::vector output_ops = {TFU.GetOperationByName("output_class/Softmax", 0)};
std::vector output_tensors = {nullptr};
status = TFU.RunSession(input_ops, input_tensors,
output_ops, output_tensors);
if (status == TFUtils::SUCCESS) {
const std::vector<:vector>> data = TFUtils::GetTensorsData(output_tensors);
const std::vector result = data[0];
int pred_index = ArgMax(result);
// Print test accuracy
printf("Predict: %d Label: %s", pred_index, class_names[pred_index].c_str());
} else {
std::cout << "Error run session";
return 2;
}
TFUtils::DeleteTensors(input_tensors);
TFUtils::DeleteTensors(output_tensors);
return 0;
}
// Returns the index of the largest element of `result`, or -1 when `result`
// is empty. Each element is echoed to stdout as "value[i] = x" while scanning.
// Ties resolve to the first occurrence of the maximum.
int ArgMax(const std::vector<float> result)
{
    int max_index = -1;
    float max_value = 0.0f;
    for (std::size_t i = 0; i < result.size(); ++i) {
        const float value = result[i];
        // Seed with the first element instead of a fixed sentinel such as -1.0,
        // so inputs whose values are all below the sentinel (e.g. raw logits)
        // still produce a valid index rather than -1.
        if (max_index < 0 || value > max_value) {
            max_index = static_cast<int>(i);
            max_value = value;
        }
        std::cout << "value[" << i << "] = " << value << std::endl;
    }
    return max_index;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112// Licensed under the MIT License .
// Copyright (c) 2018 Liu Xiao and Daniil Goncharov .
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#include "../utils/TFUtils.hpp"
#include "utils/mat2tensor_c_cpi.h"
#include
#include
// OpenCV
#include
#include
//std::string class_names[10] = {'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'};
// Human-readable labels indexed by the model's output class id.
// NOTE(review): these look like the Fashion-MNIST classes in their standard
// order — confirm they match the label order used during training.
const std::string class_names[] = {"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"};
// Returns the index of the largest score in `result` (defined after main).
int ArgMax(const std::vector<float> result);
intmain(intargc,char*argv[])
{
if(argc!=3)
{
std::cerr<<:endl . path_to_graph.pb path_to_image.png>
return1;
}
// Load graph
std::stringgraph_path=argv[1];
// TFUtils init
TFUtilsTFU;
TFUtils::STATUSstatus=TFU.LoadModel(graph_path);
if(status!=TFUtils::SUCCESS){
std::cerr<
return1;
}
// Load image and convert to tensor
std::stringimage_path=argv[2];
cv::Matimage=cv::imread(image_path,CV_LOAD_IMAGE_GRAYSCALE);
conststd::vector<:int64_t>input_dims={1,image.size().height,image.size().width,image.channels()};
TF_Tensor*input_image=Mat2Tensor(image,1/255.0);
// Input Tensor/Ops Create
conststd::vectorinput_tensors={input_image};
conststd::vectorinput_ops={TFU.GetOperationByName("input_image_input",0)};
// Output Tensor/Ops Create
conststd::vectoroutput_ops={TFU.GetOperationByName("output_class/Softmax",0)};
std::vectoroutput_tensors={nullptr};
status=TFU.RunSession(input_ops,input_tensors,
output_ops,output_tensors);
if(status==TFUtils::SUCCESS){
conststd::vector<:vector>>data=TFUtils::GetTensorsData(output_tensors);
conststd::vectorresult=data[0];
intpred_index=ArgMax(result);
// Print test accuracy
printf("Predict: %d Label: %s",pred_index,class_names[pred_index].c_str());
}else{
std::cout<
return2;
}
TFUtils::DeleteTensors(input_tensors);
TFUtils::DeleteTensors(output_tensors);
return0;
}
// Returns the index of the largest element of `result`, or -1 when `result`
// is empty. Each element is echoed to stdout as "value[i] = x" while scanning.
// Ties resolve to the first occurrence of the maximum.
int ArgMax(const std::vector<float> result)
{
    int max_index = -1;
    float max_value = 0.0f;
    for (std::size_t i = 0; i < result.size(); ++i) {
        const float value = result[i];
        // Seed with the first element instead of a fixed sentinel such as -1.0,
        // so inputs whose values are all below the sentinel (e.g. raw logits)
        // still produce a valid index rather than -1.
        if (max_index < 0 || value > max_value) {
            max_index = static_cast<int>(i);
            max_value = value;
        }
        std::cout << "value[" << i << "] = " << value << std::endl;
    }
    return max_index;
}
编译运行这一文件,如果没有问题,则会得到如下输出:
C
value[0] = 6.40457e-09
value[1] = 2.41816e-07
value[2] = 3.60118e-08
value[3] = 1.18324e-09
value[4] = 6.13108e-11
value[5] = 0.00021271
value[6] = 2.01991e-11
value[7] = 3.94614e-05
value[8] = 1.17029e-10
value[9] = 0.999748
Predict: 9 Label: Ankle boot
1
2
3
4
5
6
7
8
9
10
value[0] = 6.40457e-09
value[1] = 2.41816e-07
value[2] = 3.60118e-08
value[3] = 1.18324e-09
value[4] = 6.13108e-11
value[5] = 0.00021271
value[6] = 2.01991e-11
value[7] = 3.94614e-05
value[8] = 1.17029e-10
value[9] = 0.999748
Predict: 9 Label: Ankle boot
至此,我们就完成了使用 C API 加载并运行 TensorFlow 模型的完整流程。
Original content here is published under these license terms:X
License Type:Read Only
License Abstract:You may read the original content in the context in which it is published (at this web address). No other copying or use is permitted without written agreement from the author.