python重写tf.nn.ctc_greedy_decoder

greedy search

参考链接:https://zhuanlan.zhihu.com/p/39266552
基本原理就是将每个时间 t 内最大概率的 k 取出即可。
下面通过一个例子来阐述:

y 的分布如下:
在这里插入图片描述
那么greedy search的结果为
在这里插入图片描述
例如当 t=1时,在序列 [0.25,0.4,0.35]中得到最大概率为0.4,依次找到各时间内的最大概率即可。

代码实现:
应用在crnn_ctc文字识别中
输入inference_output大小为[25,1,37],其中25为单词宽度,1为batch_size, 37为字符类别。

import json
import numpy as np


class GreedyDecoder(object):
    def __init__(self, char_dict_path, ord_map_dict_path):
        self._char_dict = self.read_json(char_dict_path)
        self._ord_map = self.read_json(ord_map_dict_path)

    def read_json(self, dict_path):
        """
        :param dict_path:
        :return: a dict with ord(char) as key and char as value
        """
        with open(dict_path, 'r', encoding='utf-8') as json_f:
            res = json.load(json_f)
        return res

    def int_to_char(self, number):
        """
        convert the int index into char
        :param number: Can be passed as string representing the integer value to look up.
        :return: Character corresponding to 'number' in the char_dict
        """
        # 1 is the default value in sparse_tensor_to_str() This will be skipped when building the resulting strings
        if number == 1 or number == '1':
            return '\x00'
        else:
            return self._char_dict[str(number) + '_ord']

    def remove_repeat(self, repeat_char_list):
    #删除重复字符和空格
        res_batch = []
        for charlist in repeat_char_list:
            pre_ch = 36  # 36代表空格
            res = []
            for ch in charlist:
                if ch != 36:
                    if ch != pre_ch:
                        res.append(ch)
                        pre_ch = ch  # 记录前一个字符
                else:
                    pre_ch = ch
            res_batch.append(res)
        return res_batch

    def greedy_decode(self, inference_output):
        pred_values = []
        pred_list = np.argmax(inference_output, axis=2)
        pred_batch = len(pred_list[0])
        for i in range(pred_batch):
            pred_values.append(pred_list[:, i].tolist())
        preds = self.remove_repeat(pred_values)  # 去除重复字符和空格
        number_lists = []
        for pred in preds:
            number_lists.append(np.array([self._ord_map[str(tmp) + '_index'] for tmp in pred]))  # 将类别转换为ascii
        str_lists = []
        res = []
        for number_list in number_lists:
            # Translate from ord() values into characters
            str_lists.append([self.int_to_char(val) for val in number_list])
        for str_list in str_lists:
            # int_to_char() returns '\x00' for an input == 1, which is the default
            # value in number_lists, so we skip it when building the result
            res.append(''.join(c for c in str_list if c != '\x00'))
        return res

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值