写在前面:经过几天的折腾,终于通关了TFLite
本地编译,tensorflow
模型转TFLite
,TFlite
推理CRNN
网络模型。现在把整个过程记录下来。
环境:
操作系统:ubuntu18.04.5
tensorflow版本:v2.2.0
bazel版本:3.1.0
1.TFLite本地编译
TFLite本地编译可以参考我的前面一篇博客《ubuntu18.04 tensorflow以及tensorflow lite源码编译C++库》,地址:https://blog.csdn.net/guo1988kui/article/details/103696188
2.tensorflow模型转TFLite
converter=tf.lite.TFLiteConverter.from_saved_model('./pb_model')
tflite_model=converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
./pb_model
是采用tf.saved_model.save(model, './pb_model')
方式保存的pb
格式模型。
3.TFLite推理CRNN网络
3.1.出现的问题—缺少flatbuffers.h头文件
解决办法:进入tensorflow-master/tensorflow/lite/micro/tools/make
文件夹,找到flatbuffers_download.sh
文件,终端输入命令:sudo ./flatbuffers_download.sh
,下载的文件存在于/tmp
目录下,名为
dca12522a9f9e37f126ab925fd385c807ab4f84e
打开flatbuffers_download.sh
我们可以看到
3.2.本地C++部署过程中需要包含以下头文件
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/tools/gen_op_registration.h"
3.3.本地推理过程主要代码片段
#初始化
static tflite::ops::builtin::BuiltinOpResolver resolver;
static unique_ptr<tflite::Interpreter> interpreter;
static int in_index=-1;
static int out_index=-1;
static unique_ptr<FlatBufferModel> model=tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
InterpreterBuilder(*model,resolver)(&interpreter);
if(interpreter == nullptr)
{
cerr<<"Error get interpreter!"<<endl;
return false;
}
interpreter->SetAllowFp16PrecisionForFp32(true);
interpreter->SetNumThreads(4);
interpreter->AllocateTensors();
if(interpreter->AllocateTensors() != kTfLiteOk)
{
cerr<<"Error AllocateTensors!"<<endl;
return false;
}
in_index = interpreter->inputs()[0];
out_index = interpreter->outputs()[0];
#推理过程
static string num_dict="0123456789-";
memcpy(interpreter->typed_tensor<float>(in_index), &((float*)img.data)[0], img.cols*img.rows*sizeof(float));
interpreter->Invoke();
float* output = interpreter->typed_tensor<float>(out_index);
#CRNN输出的维度57,11,这里要改成你自己的维度
int rows=57;
int cols=11;
vector<int> recognition_result;
for (size_t i = 0; i < rows; i++)
{
float sum=0;
vector<pair<float,int> > row_result;
for (size_t j = 0; j < cols; j++)
{
row_result.emplace_back(pair<float,int>(output[i*cols+j],j));
sum+=exp(output[i*cols+j]);
}
sort(row_result.begin(),row_result.end(),cmp);
float prob=exp(row_result[0].first)/sum;
if(prob>0.9)
recognition_result.emplace_back(row_result[0].second);
}
string str_result="";
for (size_t i = 0; i < recognition_result.size(); i++)
{
if(recognition_result[i]!=10)
str_result+=num_dict[recognition_result[i]];
}
为了节省时间实际上,CTC
的解码过程退化为取每一个时间片中最大概率的类别。