PyTorch CRNN 笔记(三)

94 篇文章 12 订阅
4 篇文章 1 订阅

本想自己从头写起,查了一下有人实现过,那我就只剩验证和改善的工作了。

参考博客:Pytorch模型部署 - Libtorch(crnn模型部署)

Step1: 模型转换

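将 PyTorch 训练好的 CRNN 模型(.pth)通过 `torch.jit.trace` 转换为 LibTorch 能够读取的 TorchScript 模型(.pt)。注意:trace 前应先调用 `model.eval()`,否则模型中若含有 BatchNorm/Dropout,它们会以训练模式被记录进导出的模型。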

# conversion.py
# Convert a trained PyTorch CRNN checkpoint (.pth) into a TorchScript module
# (.pt) that LibTorch can load with torch::jit::load.
#
# NOTE(review): CRNN, keys and model_path are not defined in this snippet;
# they presumably come from the project (see the crnn_libtorch repo) and
# must be imported/defined before running this script.
# The original `import torchvison` (misspelled -> ImportError, and unused)
# has been dropped; only torch is needed here.
from collections import OrderedDict

import torch

model = CRNN(32, 1, len(keys.alphabetEnglish) + 1, 256, 1).cpu()

state_dict = torch.load(
    model_path, map_location=lambda storage, loc: storage)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Checkpoints saved from nn.DataParallel prefix every key with `module.`;
    # strip it so the keys match a plain (non-parallel) CRNN instance.
    name = k.replace('module.', '')
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)

# Tracing records whatever train/eval behaviour is active (BatchNorm,
# Dropout, ...), so switch to inference mode before tracing.
model.eval()

# convert pth-model to pt-model
# Example input shape: (batch=1, channels=1, height=32, width=512).
example = torch.rand(1, 1, 32, 512)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("src/crnn.pt")

代码较长,完整代码见 GitHub:crnn_libtorch。

Step2: 模型部署

利用 LibTorch + OpenCV 实现对文字条(单行文本图像)的识别。

//CrnnDeploy.h
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <torch/torch.h>
#include <torch/script.h>
#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>

#ifndef CRNN_H
#define CRNN_H

// Wraps a TorchScript CRNN model together with its character table:
// loads both in the constructor, turns an image file into an input tensor,
// and prints the decoded text of a forward pass.
class Crnn{
    public:
        // Loads the TorchScript module from modelFile and the character
        // table from keyFile.
        Crnn(std::string& modelFile, std::string& keyFile);
        // Reads imgFile as grayscale, resizes it to height 32 (keeping the
        // aspect ratio) and maps pixels from [0, 255] to [-1, 1].
        // isbath == true  -> tensor of shape {1, 32, W} (no batch dimension);
        // isbath == false -> tensor of shape {1, 1, 32, W}.
        torch::Tensor loadImg(std::string& imgFile, bool isbath=false);
        // Runs the model on input, greedily decodes each image's prediction
        // (class 0 skipped, consecutive repeats collapsed) and prints one
        // line of text per image.
        void infer(torch::Tensor& input);
    private:
        // TorchScript CRNN, loaded in the constructor.
        torch::jit::script::Module m_module;
        // Class index -> character string; entry 0 is a " " placeholder
        // inserted by readKeys.
        std::vector<std::string> m_keys;
        // Splits the whole key file into UTF-8 characters, after the
        // leading " " placeholder.
        std::vector<std::string> readKeys(const std::string& keyFile);
        // Loads the TorchScript module from modelFile via torch::jit::load.
        torch::jit::script::Module loadModule(const std::string& modelFile);
};

#endif//CRNN_H
/*
@author
date: 2020-03-17
Description:
    Deploy crnn model with libtorch.
*/

#include "CrnnDeploy.h"
#include <thread>
#include <sys/time.h>

//constructor
// Build the recogniser: first the TorchScript model, then the character
// table. The initialiser order matches the member declaration order
// (m_module before m_keys), so the model is still loaded first.
Crnn::Crnn(std::string& modelFile, std::string& keyFile)
    : m_module(loadModule(modelFile)),
      m_keys(readKeys(keyFile))
{
}


/**
 * Read `imgFile` as a single-channel grayscale image and convert it into a
 * normalised float tensor for the CRNN.
 *
 * The image is resized to a height of 32 pixels, keeping its aspect ratio,
 * and pixel values are mapped from [0, 255] to [-1, 1] via
 * (x / 255 - 0.5) / 0.5.
 *
 * @param imgFile path of the image to read.
 * @param isbath  true  -> tensor of shape {1, 32, W} (no batch dimension);
 *                false -> tensor of shape {1, 1, 32, W}.
 * @return float tensor that owns its own storage.
 * @throws std::runtime_error if the image cannot be read.
 */
torch::Tensor Crnn::loadImg(std::string& imgFile, bool isbath){
    constexpr int kTargetHeight = 32;  // input height used when exporting the model

    cv::Mat input = cv::imread(imgFile, cv::IMREAD_GRAYSCALE);
    if(input.empty()){
        // Fail loudly: carrying on would divide by input.rows == 0 below.
        throw std::runtime_error("Error: not image data, imgFile input wrong!! (" + imgFile + ")");
    }

    // Width that preserves the aspect ratio; clamp to 1 so very tall images
    // never produce a zero-width cv::Size.
    const int resizeW = std::max(1, input.cols * kTargetHeight / input.rows);
    cv::resize(input, input, cv::Size(resizeW, kTargetHeight));

    // from_blob only wraps the Mat's buffer. The byte -> float conversion
    // below allocates new storage, so the returned tensor outlives `input`.
    torch::Tensor imgTensor = torch::from_blob(input.data, {1, kTargetHeight, resizeW}, torch::kByte);
    imgTensor = imgTensor.toType(torch::kFloat).div(255).sub(0.5).div(0.5);
    if(!isbath){
        imgTensor = imgTensor.unsqueeze(0);  // add the batch dimension -> {1, 1, 32, W}
    }
    return imgTensor;
}

namespace {

// Greedy CTC decoding of one sequence of per-timestep class indices.
// A timestep emits keys[id] unless id is 0 (the blank) or equals the
// previous timestep's id (collapsed repeat).
// Throws std::out_of_range if an id has no entry in keys.
std::string ctcGreedyDecode(const std::vector<int64_t>& ids,
                            const std::vector<std::string>& keys){
    std::string text;
    for(std::size_t t = 0; t < ids.size(); ++t){
        const int64_t id = ids[t];
        const bool isRepeat = (t > 0 && ids[t - 1] == id);
        if(id != 0 && !isRepeat){
            text += keys.at(static_cast<std::size_t>(id));
        }
    }
    return text;
}

}  // namespace

/**
 * Run the CRNN on `input` and print the decoded text of every image in the
 * batch, one line per image.
 *
 * The model output is indexed as (time step, image, class), which is the
 * layout the original per-timestep argmax relied on.
 * NOTE(review): confirm this matches the exported model.
 *
 * @param input preprocessed image tensor, e.g. from loadImg().
 * @throws std::runtime_error if the model output is not 3-D.
 * @throws std::out_of_range if a predicted class index has no entry in m_keys.
 */
void Crnn::infer(torch::Tensor& input){
    // Inference only: don't record autograd history.
    torch::NoGradGuard noGrad;

    torch::Tensor output = this->m_module.forward({input}).toTensor();
    if(output.dim() != 3){
        throw std::runtime_error("Crnn::infer: expected a 3-D (T, B, C) model output, got "
                                 + std::to_string(output.dim()) + "-D");
    }

    // Argmax over the class dimension for every (t, b) at once, copied to the
    // CPU a single time instead of one .item() round-trip per element.
    torch::Tensor preds = std::get<1>(output.max(2)).to(torch::kCPU).contiguous();
    auto acc = preds.accessor<int64_t, 2>();
    const int64_t numSteps = preds.size(0);
    const int64_t numImgs = preds.size(1);

    std::vector<int64_t> ids(static_cast<std::size_t>(numSteps));
    for(int64_t b = 0; b < numImgs; ++b){
        for(int64_t t = 0; t < numSteps; ++t){
            ids[static_cast<std::size_t>(t)] = acc[t][b];
        }
        std::cout << ctcGreedyDecode(ids, this->m_keys) << '\n';
    }
}

/**
 * Read the character table used to map class indices to text.
 *
 * The whole file is split into UTF-8 characters, each stored as its own
 * string. A " " placeholder is prepended at index 0. The original author's
 * note reads: "the function filtered out the first space; add it back here".
 * Index 0 is also the class infer() skips, so entry k (k >= 1) is the k-th
 * character of the file.
 * NOTE(review): every byte is taken as-is, so a trailing '\n' in the file
 * becomes an extra entry. Confirm the expected file format.
 *
 * @param keyFile path of the key (alphabet) file.
 * @return placeholder followed by the file's characters.
 * @throws std::runtime_error if the file cannot be opened.
 */
std::vector<std::string> Crnn::readKeys(const std::string& keyFile){
    std::ifstream in(keyFile);
    if(!in){
        // An empty table would make every later lookup out of range.
        throw std::runtime_error("Crnn::readKeys: cannot open key file: " + keyFile);
    }
    std::ostringstream tmp;
    tmp << in.rdbuf();
    const std::string keys = tmp.str();

    std::vector<std::string> words;
    words.push_back(" ");

    std::size_t i = 0;
    while(i < keys.size()){
        // Inspect the lead byte as unsigned so the bit tests don't depend on
        // whether plain char is signed.
        const unsigned char lead = static_cast<unsigned char>(keys[i]);
        assert(lead < 0xF8);  // 0xF8..0xFF never start a valid UTF-8 sequence
        std::size_t next = 1;  // ASCII (0xxxxxxx) or a stray continuation byte
        if((lead & 0xE0) == 0xC0){
            next = 2;          // 110xxxxx
        }else if((lead & 0xF0) == 0xE0){
            next = 3;          // 1110xxxx
        }else if((lead & 0xF8) == 0xF0){
            next = 4;          // 11110xxx
        }
        words.push_back(keys.substr(i, next));  // substr clamps a truncated tail
        i += next;
    }
    return words;
}

/**
 * Load the TorchScript model from `modelFile`.
 *
 * @param modelFile path of the .pt file produced by torch.jit.trace(...).save().
 * @return the loaded module.
 * @throws c10::Error (rethrown after logging) if torch::jit::load fails.
 *         Returning a default-constructed Module instead would only postpone
 *         the failure to the first forward() call and lose the reason.
 */
torch::jit::script::Module Crnn::loadModule(const std::string& modelFile){
    try{
        return torch::jit::load(modelFile);
    }catch(const c10::Error& e){
        std::cerr << "error loadding the model !!!\n" << e.what() << '\n';
        throw;
    }
}


// Current wall-clock time in milliseconds since the Unix epoch, read with
// gettimeofday(); main() only uses differences between two calls.
long getCurrentTime(void){
    timeval now{};
    gettimeofday(&now, nullptr);
    // microseconds -> ms (truncating) plus seconds -> ms
    return now.tv_usec / 1000 + now.tv_sec * 1000;
}

/**
 * Command-line entry point: CrnnDeploy <model.pt> <keys file> <image>.
 *
 * Loads the model and character table, recognises the image, prints the
 * decoded text and the elapsed time (model loading included).
 * Returns -1 on missing arguments or if any step throws.
 */
int main(int argc, const char* argv[]){

    if(argc<4){
        printf("Error use CrnnDeploy: loss input param !!! \n");
        printf("Usage: %s <model.pt> <keys file> <image>\n", argc > 0 ? argv[0] : "CrnnDeploy");
        return -1;
    }
    std::string modelFile = argv[1];
    std::string keyFile = argv[2];
    std::string imgFile = argv[3];

    try{
        const long t1 = getCurrentTime();
        {
            // Automatic storage instead of new/delete: the recogniser is
            // destroyed even if loading or inference throws.
            Crnn crnn(modelFile, keyFile);
            torch::Tensor input = crnn.loadImg(imgFile);
            crnn.infer(input);
        }  // destructor runs here, so the timing still covers teardown
        const long t2 = getCurrentTime();

        printf("ocr time : %ld ms \n", (t2-t1));
    }catch(const std::exception& e){
        // c10::Error and cv::Exception both derive from std::exception.
        std::cerr << "CrnnDeploy failed: " << e.what() << std::endl;
        return -1;
    }
    return 0;
}

验证结果(图片与识别结果):

 

 

 

 

 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

落花逐流水

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值