TFLite推理CRNN网络模型,0-9数字识别

写在前面:经过几天的折腾,终于通关了TFLite本地编译,tensorflow模型转TFLiteTFlite推理CRNN网络模型。现在把整个过程记录下来。

环境:

操作系统:ubuntu18.04.5
tensorflow版本:v2.2.0
bazel版本:3.1.0

1.TFLite本地编译

TFLite本地编译可以参考我的前面一篇博客《ubuntu18.04 tensorflow以及tensorflow lite源码编译C++库》,地址:https://blog.csdn.net/guo1988kui/article/details/103696188

2.tensorflow模型转TFLite

converter=tf.lite.TFLiteConverter.from_saved_model('./pb_model')
tflite_model=converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

./pb_model是采用tf.saved_model.save(model, './pb_model')方式保存的pb格式模型。

3.TFLite推理CRNN网络

3.1.出现的问题—缺少flatbuffers.h头文件

解决办法:进入tensorflow-master/tensorflow/lite/micro/tools/make文件夹,找到flatbuffers_download.sh文件,终端输入命令:sudo ./flatbuffers_download.sh,下载的文件存在于/tmp目录下,名为

dca12522a9f9e37f126ab925fd385c807ab4f84e

打开flatbuffers_download.sh我们可以看到
在这里插入图片描述

3.2.本地C++部署过程中需要包含以下头文件
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/tools/gen_op_registration.h"
3.3.本地推理过程主要代码片段
#初始化
static tflite::ops::builtin::BuiltinOpResolver resolver;
static unique_ptr<tflite::Interpreter> interpreter;
static int in_index=-1;
static int out_index=-1;

static unique_ptr<FlatBufferModel> model=tflite::FlatBufferModel::BuildFromFile(model_path.c_str());
InterpreterBuilder(*model,resolver)(&interpreter);
if(interpreter == nullptr)
{
	cerr<<"Error get interpreter!"<<endl;
	return false;
}
interpreter->SetAllowFp16PrecisionForFp32(true);
interpreter->SetNumThreads(4);

interpreter->AllocateTensors();
if(interpreter->AllocateTensors() != kTfLiteOk)
{
	cerr<<"Error AllocateTensors!"<<endl;
	return false;
}
in_index = interpreter->inputs()[0];
out_index = interpreter->outputs()[0];
#推理过程
static string num_dict="0123456789-";
memcpy(interpreter->typed_tensor<float>(in_index), &((float*)img.data)[0], img.cols*img.rows*sizeof(float));

interpreter->Invoke();

float* output = interpreter->typed_tensor<float>(out_index);
#CRNN输出的维度57,11,这里要改成你自己的维度
int rows=57;
int cols=11;
vector<int> recognition_result;
for (size_t i = 0; i < rows; i++)
{
    float sum=0;
    vector<pair<float,int> > row_result;
    for (size_t j = 0; j < cols; j++)
    {
        row_result.emplace_back(pair<float,int>(output[i*cols+j],j));
        sum+=exp(output[i*cols+j]);
    }
    sort(row_result.begin(),row_result.end(),cmp);
    float prob=exp(row_result[0].first)/sum;
    
    if(prob>0.9)
        recognition_result.emplace_back(row_result[0].second);
}
string str_result="";
for (size_t i = 0; i < recognition_result.size(); i++)
{
    
    if(recognition_result[i]!=10)
        str_result+=num_dict[recognition_result[i]];
}

为了节省时间实际上,CTC的解码过程退化为取每一个时间片中最大概率的类别。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值