TensorFlow C API from Training to Deployment: Prediction and Deployment with the C API

The earlier posts Tensorflow C++ from Training to Deployment (2): Saving and Loading a Simple Graph and Building with CMake and Tensorflow C++ from Training to Deployment (3): Training and Deploying a CNN with Keras trained models with the TensorFlow/Keras Python API and ran prediction through the C++ API. The C++ API, however, requires building TensorFlow from source, which is fairly cumbersome. TensorFlow officially ships prebuilt C API library files, so deployment is much simpler: you just copy the library files into your own project. This post walks through running prediction with the C API. The Python training part is the same as in the earlier articles and is not repeated here.

0. System Environment

Ubuntu 16.04

Tensorflow 1.12.0

1. Installing Dependencies

1) GPU support (optional)

CUDA 9.0

cuDNN 7.x

Download links for the 1.12.0 C library are as follows (I provide several builds, including one for the TX2 aarch64 platform):

Extract the library into the third_party/libtensorflow directory.

If none of the versions above fits your needs, you can build the version you want yourself by following this article.
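Before moving on, it can be worth verifying that the extracted headers and shared library are picked up correctly. The snippet below is only a sketch: the include and library paths in the build comment are assumptions based on the third_party/libtensorflow layout above, so adjust them to your directory structure. TF_Version() is part of the C API and simply returns the version string of the loaded library.

C++

// check_tf_version.cc -- minimal sanity check that libtensorflow links and loads.
// Example build command (paths are assumptions, adjust to your layout):
//   g++ check_tf_version.cc -I third_party/libtensorflow/include \
//       -L third_party/libtensorflow/lib -ltensorflow -o check_tf_version
#include <cstdio>
#include <tensorflow/c/c_api.h>

int main() {
  // TF_Version() needs no graph or session; it just reports the library version.
  std::printf("TensorFlow C library version: %s\n", TF_Version());
  return 0;
}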

2. The TFUtils Helper Class

For convenience, we first wrap the commonly used C API calls into a helper class, TFUtils.

1) File utils/TFUtils.hpp:

C++

// Licensed under the MIT License.
// Copyright (c) 2018 Liu Xiao and Daniil Goncharov.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#pragma once

#if defined(_MSC_VER)
#  if !defined(COMPILER_MSVC)
#    define COMPILER_MSVC // Set MSVC visibility of exported symbols in the shared library.
#  endif
#  pragma warning(push)
#  pragma warning(disable : 4190)
#endif

#include <tensorflow/c/c_api.h> // TensorFlow C API.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

class TFUtils {
public:
  enum STATUS
  {
    SUCCESS = 0,
    SESSION_CREATE_FAILED = 1,
    MODEL_LOAD_FAILED = 2,
    FAILED_RUN_SESSION = 3,
    MODEL_NOT_LOADED = 4,
  };

  TFUtils();

  STATUS LoadModel(std::string model_file);

  ~TFUtils();

  TF_Output GetOperationByName(std::string name, int idx);

  STATUS RunSession(const std::vector<TF_Output>& inputs, const std::vector<TF_Tensor*>& input_tensors,
                    const std::vector<TF_Output>& outputs, std::vector<TF_Tensor*>& output_tensors);

  // Static functions
  template <typename T>
  static TF_Tensor* CreateTensor(TF_DataType data_type,
                                 const std::vector<std::int64_t>& dims,
                                 const std::vector<T>& data) {
    return CreateTensor(data_type,
                        dims.data(), dims.size(),
                        data.data(), data.size() * sizeof(T));
  }

  static void DeleteTensor(TF_Tensor* tensor);
  static void DeleteTensors(const std::vector<TF_Tensor*>& tensors);

  template <typename T>
  static std::vector<std::vector<T>> GetTensorsData(const std::vector<TF_Tensor*>& tensors) {
    std::vector<std::vector<T>> data;
    data.reserve(tensors.size());
    for (const auto t : tensors) {
      data.push_back(GetTensorData<T>(t));
    }
    return data;
  }

  static TF_Tensor* CreateTensor(TF_DataType data_type,
                                 const std::int64_t* dims, std::size_t num_dims,
                                 const void* data, std::size_t len);

  template <typename T>
  static std::vector<T> GetTensorData(const TF_Tensor* tensor) {
    const auto data = static_cast<T*>(TF_TensorData(tensor));
    if (data == nullptr) {
      return {};
    }
    return {data, data + (TF_TensorByteSize(tensor) / TF_DataTypeSize(TF_TensorType(tensor)))};
  }

  // STATUS GetErrorCode();

  static void PrinStatus(STATUS status);

private:
  TF_Graph* graph_def;
  TF_Session* sess;
  STATUS init_error_code;

private:
  TF_Graph* LoadGraphDef(const char* file);
  TF_Session* CreateSession(TF_Graph* graph);
  bool CloseAndDeleteSession(TF_Session* sess);

  bool RunSession(TF_Session* sess,
                  const TF_Output* inputs, TF_Tensor* const* input_tensors, std::size_t ninputs,
                  const TF_Output* outputs, TF_Tensor** output_tensors, std::size_t noutputs);

  bool RunSession(TF_Session* sess,
                  const std::vector<TF_Output>& inputs, const std::vector<TF_Tensor*>& input_tensors,
                  const std::vector<TF_Output>& outputs, std::vector<TF_Tensor*>& output_tensors);
}; // End class TFUtils

#if defined(_MSC_VER)
#  pragma warning(pop)
#endif


2) File utils/TFUtils.cpp:

C++

// Licensed under the MIT License.
// Copyright (c) 2018 Liu Xiao and Daniil Goncharov.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#if defined(_MSC_VER)
#  pragma warning(push)
#  pragma warning(disable : 4996)
#endif

#include "TFUtils.hpp"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <vector>

// Public functions
TFUtils::TFUtils()
{
  init_error_code = MODEL_NOT_LOADED;
}

// Public functions
TFUtils::STATUS TFUtils::LoadModel(std::string model_file)
{
  // Load graph
  graph_def = LoadGraphDef(model_file.c_str());
  if (graph_def == nullptr) {
    std::cerr << "loading model failed ......" << std::endl;
    init_error_code = MODEL_LOAD_FAILED;
    return MODEL_LOAD_FAILED;
  }

  // Create session
  sess = CreateSession(graph_def);
  if (sess == nullptr) {
    init_error_code = SESSION_CREATE_FAILED;
    std::cerr << "create sess failed ......" << std::endl;
    return SESSION_CREATE_FAILED;
  }

  init_error_code = SUCCESS;
  return init_error_code;
}

TFUtils::~TFUtils()
{
  if (sess)
    CloseAndDeleteSession(sess);
  if (graph_def)
    TF_DeleteGraph(graph_def);
}

TF_Output TFUtils::GetOperationByName(std::string name, int idx) {
  return {TF_GraphOperationByName(graph_def, name.c_str()), idx};
}

TFUtils::STATUS TFUtils::RunSession(const std::vector<TF_Output>& inputs, const std::vector<TF_Tensor*>& input_tensors,
                                    const std::vector<TF_Output>& outputs, std::vector<TF_Tensor*>& output_tensors)
{
  if (init_error_code != SUCCESS)
    return init_error_code;

  bool run_ret = RunSession(sess, inputs, input_tensors, outputs, output_tensors);
  if (run_ret == false)
    return FAILED_RUN_SESSION;

  return SUCCESS;
}

void TFUtils::PrinStatus(STATUS status)
{
  switch (status) {
    case SUCCESS:
      std::cout << "status = SUCCESS" << std::endl;
      break;
    case SESSION_CREATE_FAILED:
      std::cout << "status = SESSION_CREATE_FAILED" << std::endl;
      break;
    case MODEL_LOAD_FAILED:
      std::cout << "status = MODEL_LOAD_FAILED" << std::endl;
      break;
    case FAILED_RUN_SESSION:
      std::cout << "status = FAILED_RUN_SESSION" << std::endl;
      break;
    case MODEL_NOT_LOADED:
      std::cout << "status = MODEL_NOT_LOADED" << std::endl;
      break;
    default:
      std::cout << "status = NOT FOUND" << std::endl;
  }
}

// Static functions
static void DeallocateBuffer(void* data, size_t) {
  std::free(data);
}

static TF_Buffer* ReadBufferFromFile(const char* file) {
  const auto f = std::fopen(file, "rb");
  if (f == nullptr) {
    return nullptr;
  }

  std::fseek(f, 0, SEEK_END);
  const auto fsize = ftell(f);
  std::fseek(f, 0, SEEK_SET);

  if (fsize < 1) {
    std::fclose(f);
    return nullptr;
  }

  const auto data = std::malloc(fsize);
  std::fread(data, fsize, 1, f);
  std::fclose(f);

  TF_Buffer* buf = TF_NewBuffer();
  buf->data = data;
  buf->length = fsize;
  buf->data_deallocator = DeallocateBuffer;

  return buf;
}

// Private functions
TF_Graph* TFUtils::LoadGraphDef(const char* file) {
  if (file == nullptr) {
    return nullptr;
  }

  TF_Buffer* buffer = ReadBufferFromFile(file);
  if (buffer == nullptr) {
    return nullptr;
  }

  TF_Graph* graph = TF_NewGraph();
  TF_Status* status = TF_NewStatus();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();

  TF_GraphImportGraphDef(graph, buffer, opts, status);
  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(buffer);

  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteGraph(graph);
    graph = nullptr;
  }

  TF_DeleteStatus(status);

  return graph;
}

TF_Session* TFUtils::CreateSession(TF_Graph* graph) {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* options = TF_NewSessionOptions();
  TF_Session* sess = TF_NewSession(graph, options, status);
  TF_DeleteSessionOptions(options);

  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteStatus(status);
    return nullptr;
  }

  return sess;
}

bool TFUtils::CloseAndDeleteSession(TF_Session* sess) {
  TF_Status* status = TF_NewStatus();
  TF_CloseSession(sess, status);
  if (TF_GetCode(status) != TF_OK) {
    TF_CloseSession(sess, status);
    TF_DeleteSession(sess, status);
    TF_DeleteStatus(status);
    return false;
  }

  TF_DeleteSession(sess, status);
  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteStatus(status);
    return false;
  }

  TF_DeleteStatus(status);

  return true;
}

bool TFUtils::RunSession(TF_Session* sess,
                         const TF_Output* inputs, TF_Tensor* const* input_tensors, std::size_t ninputs,
                         const TF_Output* outputs, TF_Tensor** output_tensors, std::size_t noutputs) {
  if (sess == nullptr ||
      inputs == nullptr || input_tensors == nullptr ||
      outputs == nullptr || output_tensors == nullptr) {
    return false;
  }

  TF_Status* status = TF_NewStatus();
  TF_SessionRun(sess,
                nullptr, // Run options.
                inputs, input_tensors, static_cast<int>(ninputs), // Input tensors, input tensor values, number of inputs.
                outputs, output_tensors, static_cast<int>(noutputs), // Output tensors, output tensor values, number of outputs.
                nullptr, 0, // Target operations, number of targets.
                nullptr, // Run metadata.
                status // Output status.
                );

  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteStatus(status);
    return false;
  }

  TF_DeleteStatus(status);

  return true;
}

bool TFUtils::RunSession(TF_Session* sess,
                         const std::vector<TF_Output>& inputs, const std::vector<TF_Tensor*>& input_tensors,
                         const std::vector<TF_Output>& outputs, std::vector<TF_Tensor*>& output_tensors) {
  return RunSession(sess,
                    inputs.data(), input_tensors.data(), input_tensors.size(),
                    outputs.data(), output_tensors.data(), output_tensors.size());
}

TF_Tensor* TFUtils::CreateTensor(TF_DataType data_type,
                                 const std::int64_t* dims, std::size_t num_dims,
                                 const void* data, std::size_t len) {
  if (dims == nullptr || data == nullptr) {
    return nullptr;
  }

  TF_Tensor* tensor = TF_AllocateTensor(data_type, dims, static_cast<int>(num_dims), len);
  if (tensor == nullptr) {
    return nullptr;
  }

  void* tensor_data = TF_TensorData(tensor);
  if (tensor_data == nullptr) {
    TF_DeleteTensor(tensor);
    return nullptr;
  }

  std::memcpy(tensor_data, data, std::min(len, TF_TensorByteSize(tensor)));

  return tensor;
}

void TFUtils::DeleteTensor(TF_Tensor* tensor) {
  if (tensor == nullptr) {
    return;
  }
  TF_DeleteTensor(tensor);
}

void TFUtils::DeleteTensors(const std::vector<TF_Tensor*>& tensors) {
  for (auto t : tensors) {
    TF_DeleteTensor(t);
  }
}

#if defined(_MSC_VER)
#  pragma warning(pop)
#endif


3. Loading and Running a Simple Graph

In the earlier post Tensorflow C++ from Training to Deployment (2): Saving and Loading a Simple Graph and Building with CMake we already described how a simple c = a*b "network" is computed.

The Python code that builds the graph and runs prediction is not repeated here; see that article for details. Here is the C API code directly:

File: simple/load_simple_net_c_api.cc

C++

// Licensed under the MIT License.
// Copyright (c) 2018 Liu Xiao and Daniil Goncharov.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#include "../utils/TFUtils.hpp"

#include <iostream>
#include <vector>

int main(int argc, char* argv[])
{
  if (argc != 2)
  {
    std::cerr << std::endl << "Usage: ./project path_to_graph.pb" << std::endl;
    return 1;
  }

  std::string graph_path = argv[1];

  // TFUtils init
  TFUtils TFU;
  TFUtils::STATUS status = TFU.LoadModel(graph_path);
  if (status != TFUtils::SUCCESS) {
    std::cerr << "Can't load graph" << std::endl;
    return 1;
  }

  // Input Tensor Create
  const std::vector<std::int64_t> input_a_dims = {1, 1};
  const std::vector<float> input_a_vals = {2.0};
  const std::vector<std::int64_t> input_b_dims = {1, 1};
  const std::vector<float> input_b_vals = {3.0};

  const std::vector<TF_Output> input_ops = {TFU.GetOperationByName("a", 0),
                                            TFU.GetOperationByName("b", 0)};
  const std::vector<TF_Tensor*> input_tensors = {TFUtils::CreateTensor(TF_FLOAT, input_a_dims, input_a_vals),
                                                 TFUtils::CreateTensor(TF_FLOAT, input_b_dims, input_b_vals)};

  // Output Tensor Create
  const std::vector<TF_Output> output_ops = {TFU.GetOperationByName("c", 0)};
  std::vector<TF_Tensor*> output_tensors = {nullptr};

  status = TFU.RunSession(input_ops, input_tensors,
                          output_ops, output_tensors);

  TFUtils::PrinStatus(status);

  if (status == TFUtils::SUCCESS) {
    const std::vector<std::vector<float>> data = TFUtils::GetTensorsData<float>(output_tensors);
    const std::vector<float> result = data[0];
    std::cout << "Output value: " << result[0] << std::endl;
  } else {
    std::cout << "Error run session";
    return 2;
  }

  TFUtils::DeleteTensors(input_tensors);
  TFUtils::DeleteTensors(output_tensors);

  return 0;
}


Let's briefly walk through the key lines:

C++

TFUtils::STATUS status = TFU.LoadModel(graph_path);


This line loads the .pb model file.

C++

// Input Tensor Create
const std::vector<std::int64_t> input_a_dims = {1, 1};
const std::vector<float> input_a_vals = {2.0};
const std::vector<std::int64_t> input_b_dims = {1, 1};
const std::vector<float> input_b_vals = {3.0};

const std::vector<TF_Output> input_ops = {TFU.GetOperationByName("a", 0),
                                          TFU.GetOperationByName("b", 0)};
const std::vector<TF_Tensor*> input_tensors = {TFUtils::CreateTensor(TF_FLOAT, input_a_dims, input_a_vals),
                                               TFUtils::CreateTensor(TF_FLOAT, input_b_dims, input_b_vals)};

// Output Tensor Create
const std::vector<TF_Output> output_ops = {TFU.GetOperationByName("c", 0)};
std::vector<TF_Tensor*> output_tensors = {nullptr};


This block creates the two input tensors and the corresponding input ops. Note that every tensor created with CreateTensor here must later be released with DeleteTensors; a small RAII sketch for making that automatic is shown below. The output tensors do not exist yet, so the output slot is simply initialized to nullptr.
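Since the explicit DeleteTensors call is easy to forget on early-return error paths, one option is to wrap the tensors in a guard object. The TensorGuard class below is not part of TFUtils, just a minimal sketch of the idea, assuming the same TFUtils.hpp as in the example above: it owns a vector of TF_Tensor* and calls TFUtils::DeleteTensors in its destructor.

C++

#include <utility>
#include <vector>
#include "../utils/TFUtils.hpp" // Path assumed as in the example above.

// TensorGuard (hypothetical helper): owns a set of TF_Tensor* and releases
// them automatically when it goes out of scope, including on early returns.
class TensorGuard {
 public:
  explicit TensorGuard(std::vector<TF_Tensor*> tensors)
      : tensors_(std::move(tensors)) {}
  ~TensorGuard() { TFUtils::DeleteTensors(tensors_); }

  // Non-copyable: each guard must own its tensors exactly once.
  TensorGuard(const TensorGuard&) = delete;
  TensorGuard& operator=(const TensorGuard&) = delete;

  const std::vector<TF_Tensor*>& get() const { return tensors_; }

 private:
  std::vector<TF_Tensor*> tensors_;
};

// Usage sketch: the guarded inputs are freed without an explicit DeleteTensors call.
//   TensorGuard inputs({TFUtils::CreateTensor(TF_FLOAT, input_a_dims, input_a_vals),
//                       TFUtils::CreateTensor(TF_FLOAT, input_b_dims, input_b_vals)});
//   status = TFU.RunSession(input_ops, inputs.get(), output_ops, output_tensors);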

C++

status = TFU.RunSession(input_ops, input_tensors,

output_ops, output_tensors);


This call runs the session, i.e. performs the actual inference.

C++

const std::vector<std::vector<float>> data = TFUtils::GetTensorsData<float>(output_tensors);

const std::vector<float> result = data[0];


These two lines read the data out of output_tensors into a two-dimensional vector, std::vector<std::vector<float>>. Our graph has only one output, the op named "c" at index 0, so data[0] is exactly the value we want; a short sketch of how the indices line up when several outputs are fetched follows below.
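The helper below is hypothetical (not part of TFUtils, and it assumes the same TFUtils.hpp as above); it only illustrates the indexing: data[i] holds the flattened values of the i-th fetched output op, in the order the ops were passed to RunSession.

C++

#include <cstddef>
#include <iostream>
#include <vector>
#include "../utils/TFUtils.hpp" // Path assumed as in the example above.

// Hypothetical helper: prints every element of every fetched output tensor.
void PrintAllOutputs(const std::vector<TF_Tensor*>& output_tensors) {
  const std::vector<std::vector<float>> data =
      TFUtils::GetTensorsData<float>(output_tensors);
  for (std::size_t i = 0; i < data.size(); ++i) {      // one row per output op
    for (std::size_t j = 0; j < data[i].size(); ++j) { // one entry per element
      std::cout << "output[" << i << "][" << j << "] = " << data[i][j] << std::endl;
    }
  }
}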

Compile and run this file; if everything works you should see the following output:

Shell

status = SUCCESS

Output value: 6


4. Loading a CNN and Running Prediction

Similar to Section 3, a CNN follows exactly the same workflow. We again use the basic fashion_mnist example; see the earlier article for how this network is trained and saved. Here we only cover prediction with the C API. Since we need to read an image and convert it into a tensor before feeding it to the network, we write a small helper function, Mat2Tensor, to do this conversion:

1) The Mat2Tensor helper, file fashion_mnist/utils/mat2tensor_c_cpi.h:

C++

// Licensed under the MIT License.
// Copyright (c) 2018 Liu Xiao.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#ifndef TENSORFLOW_CPP_MAT2TENSOR_C_H
#define TENSORFLOW_CPP_MAT2TENSOR_C_H

#include <cstdint>
#include <vector>
#include <tensorflow/c/c_api.h>
#include "../../utils/TFUtils.hpp" // For TFUtils::CreateTensor; adjust the relative path to your layout.
#include "opencv2/core/core.hpp"

TF_Tensor* Mat2Tensor(cv::Mat &img, float normal = 1/255.0) {
  const std::vector<std::int64_t> input_dims = {1, img.size().height, img.size().width, img.channels()};

  // Convert to float 32 and do normalize ops
  cv::Mat fake_mat(img.rows, img.cols, CV_32FC(img.channels()));
  img.convertTo(fake_mat, CV_32FC(img.channels()));
  fake_mat *= normal;

  TF_Tensor* image_input = TFUtils::CreateTensor(TF_FLOAT,
      input_dims.data(), input_dims.size(),
      fake_mat.data, (fake_mat.size().height * fake_mat.size().width * fake_mat.channels() * sizeof(float)));

  return image_input;
}

#endif //TENSORFLOW_CPP_MAT2TENSOR_C_H


2) Loading the network and running prediction. This part is essentially the same as Section 3, so it is given without further explanation:

C++

// Licensed under the MIT License.
// Copyright (c) 2018 Liu Xiao and Daniil Goncharov.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#include "../utils/TFUtils.hpp"
#include "utils/mat2tensor_c_cpi.h"

#include <cstdio>
#include <iostream>
#include <vector>

// OpenCV
#include "opencv2/core/core.hpp"
#include "opencv2/highgui/highgui.hpp"

//std::string class_names[10] = {'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'};
std::string class_names[] = {"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"};

int ArgMax(const std::vector<float> result);

int main(int argc, char* argv[])
{
  if (argc != 3)
  {
    std::cerr << std::endl << "Usage: ./project path_to_graph.pb path_to_image.png" << std::endl;
    return 1;
  }

  // Load graph
  std::string graph_path = argv[1];

  // TFUtils init
  TFUtils TFU;
  TFUtils::STATUS status = TFU.LoadModel(graph_path);
  if (status != TFUtils::SUCCESS) {
    std::cerr << "Can't load graph" << std::endl;
    return 1;
  }

  // Load image and convert to tensor
  std::string image_path = argv[2];
  cv::Mat image = cv::imread(image_path, CV_LOAD_IMAGE_GRAYSCALE);
  const std::vector<std::int64_t> input_dims = {1, image.size().height, image.size().width, image.channels()};
  TF_Tensor* input_image = Mat2Tensor(image, 1/255.0);

  // Input Tensor/Ops Create
  const std::vector<TF_Tensor*> input_tensors = {input_image};
  const std::vector<TF_Output> input_ops = {TFU.GetOperationByName("input_image_input", 0)};

  // Output Tensor/Ops Create
  const std::vector<TF_Output> output_ops = {TFU.GetOperationByName("output_class/Softmax", 0)};
  std::vector<TF_Tensor*> output_tensors = {nullptr};

  status = TFU.RunSession(input_ops, input_tensors,
                          output_ops, output_tensors);

  if (status == TFUtils::SUCCESS) {
    const std::vector<std::vector<float>> data = TFUtils::GetTensorsData<float>(output_tensors);
    const std::vector<float> result = data[0];
    int pred_index = ArgMax(result);
    // Print the predicted class index and label
    printf("Predict: %d Label: %s\n", pred_index, class_names[pred_index].c_str());
  } else {
    std::cout << "Error run session";
    return 2;
  }

  TFUtils::DeleteTensors(input_tensors);
  TFUtils::DeleteTensors(output_tensors);

  return 0;
}

int ArgMax(const std::vector<float> result)
{
  float max_value = -1.0;
  int max_index = -1;
  const long count = result.size();
  for (int i = 0; i < count; ++i) {
    const float value = result[i];
    if (value > max_value) {
      max_index = i;
      max_value = value;
    }
    std::cout << "value[" << i << "] = " << value << std::endl;
  }
  return max_index;
}


Compile and run this file; if everything works you should see output like the following:

Shell

value[0] = 6.40457e-09

value[1] = 2.41816e-07

value[2] = 3.60118e-08

value[3] = 1.18324e-09

value[4] = 6.13108e-11

value[5] = 0.00021271

value[6] = 2.01991e-11

value[7] = 3.94614e-05

value[8] = 1.17029e-10

value[9] = 0.999748

Predict: 9 Label: Ankle boot


This completes the full workflow of running a TensorFlow model with the C API.

Original content here is published under these license terms:

License Type: Read Only

License Abstract: You may read the original content in the context in which it is published (at this web address). No other copying or use is permitted without written agreement from the author.
