greedy search
参考链接:https://zhuanlan.zhihu.com/p/39266552
基本原理就是：在每个时间步 t 上，取出概率最大的类别 k（即 argmax），再按时间顺序拼接即可。
下面通过一个例子来阐述:
y 的分布如下:
那么greedy search的结果为
例如当 t=1 时，在序列 [0.25, 0.4, 0.35] 中最大概率为 0.4，取出其对应的类别；对每个时间步依次取出概率最大的类别即可。
代码实现:
应用在crnn_ctc文字识别中
输入 inference_output 的大小为 [25, 1, 37]，其中 25 为时间步数（序列长度），1 为 batch_size，37 为字符类别数（代码中索引 36 作为 CTC 空白符被跳过）。
import json
import numpy as np
class GreedyDecoder(object):
    """Greedy (best-path) CTC decoder for CRNN text-recognition output.

    For every time step the most probable class is taken; consecutive repeats
    are merged, blank labels are dropped, and the remaining class indices are
    mapped to characters through two JSON lookup tables.
    """

    # Class index treated as the CTC blank and skipped when collapsing.
    # 36 matches a 37-class output whose last class is the blank. Kept as a
    # class attribute so instances created without __init__ still have it.
    blank_index = 36

    def __init__(self, char_dict_path, ord_map_dict_path, blank_index=36):
        """
        :param char_dict_path: JSON file whose '<ord>_ord' keys map to characters
        :param ord_map_dict_path: JSON file whose '<class index>_index' keys map
            to the values looked up in the char dict
        :param blank_index: class index of the CTC blank label (default 36)
        """
        self._char_dict = self.read_json(char_dict_path)
        self._ord_map = self.read_json(ord_map_dict_path)
        self.blank_index = blank_index

    def read_json(self, dict_path):
        """Load and return the parsed JSON content of ``dict_path`` (UTF-8).

        :param dict_path: path to a JSON file (expected to hold an object/dict)
        :return: the parsed JSON value
        """
        with open(dict_path, 'r', encoding='utf-8') as json_f:
            return json.load(json_f)

    def int_to_char(self, number):
        """
        convert the int index into char
        :param number: Can be passed as string representing the integer value to look up.
        :return: Character corresponding to 'number' in the char_dict
        """
        # 1 is the default value in sparse_tensor_to_str() This will be skipped when building the resulting strings
        if number == 1 or number == '1':
            return '\x00'
        return self._char_dict[str(number) + '_ord']

    def remove_repeat(self, repeat_char_list):
        """Collapse CTC label sequences: merge consecutive repeats, drop blanks.

        A label repeated on both sides of a blank is kept twice, e.g. with
        blank 36: [5, 5, 36, 5, 7] -> [5, 5, 7].

        :param repeat_char_list: iterable of per-sample label sequences
        :return: list of collapsed label lists, one per sample
        """
        blank = self.blank_index
        res_batch = []
        for charlist in repeat_char_list:
            res = []
            pre_ch = blank  # so a leading non-blank label is always emitted
            for ch in charlist:
                if ch != blank and ch != pre_ch:
                    res.append(ch)
                pre_ch = ch  # remember the previous label, blank or not
            res_batch.append(res)
        return res_batch

    def greedy_decode(self, inference_output):
        """Decode network output into strings using greedy (best-path) search.

        :param inference_output: array-like of shape
            [time_steps, batch, num_classes] (e.g. [25, 1, 37]); the argmax is
            taken over the last axis.
        :return: list of decoded strings, one per batch element
        """
        # Most probable class per time step: shape [time_steps, batch].
        pred_list = np.argmax(inference_output, axis=2)
        # One label sequence per batch element. Transposing (rather than
        # reading len(pred_list[0])) also works when there are no time steps.
        pred_values = pred_list.T.tolist()
        preds = self.remove_repeat(pred_values)

        res = []
        for pred in preds:
            # Class indices -> ord() values -> characters.
            number_list = [self._ord_map[str(tmp) + '_index'] for tmp in pred]
            str_list = [self.int_to_char(val) for val in number_list]
            # int_to_char() returns '\x00' for the default value 1; skip it.
            res.append(''.join(c for c in str_list if c != '\x00'))
        return res