2021SC@SDUSC
代码位置:tools->infer->predict_det.py
class TextDetector(object):
def __init__(self, args):
self.args = args
self.det_algorithm = args.det_algorithm
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len,
'limit_type': args.det_limit_type,
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image', 'shape']
}
}]
postprocess_params = {}
这里是一些参数的初始化。
if self.det_algorithm == "DB":
postprocess_params['name'] = 'DBPostProcess'
postprocess_params["thresh"] = args.det_db_thresh
postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh
postprocess_params["cover_thresh"] = args.det_east_cover_thresh
postprocess_params["nms_thresh"] = args.det_east_nms_thresh
elif self.det_algorithm == "SAST":
pre_process_list[0] = {
'DetResizeForTest': {
'resize_long': args.det_limit_side_len
}
}
postprocess_params['name'] = 'SASTPostProcess'
postprocess_params["score_thresh"] = args.det_sast_score_thresh
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
self.det_sast_polygon = args.det_sast_polygon
if self.det_sast_polygon:
postprocess_params["sample_pts_num"] = 6
postprocess_params["expand_scale"] = 1.2
postprocess_params["shrink_ratio_of_width"] = 0.2
else:
postprocess_params["sample_pts_num"] = 2
postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3
else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0)
这是对三大文字检测算法的选择,即EAST算法,SAST算法和DB算法。
-
EAST算法:
- 简单而强大的pipeLine,可以在自然场景中进行快速准确的文本检测。该Pipeline直接预测图像中任意方向和矩形形状的文本或文本行,通过单个神经网络消除不必要的中间步骤(例如候选聚合和单词分割)。
- 仅有两个阶段:一个阶段是基于全卷积网络(FCN)模型,直接产生文本框预测;第二个阶段是对生成的文本预测框(可以是旋转矩形或矩形)经过非极大值抑制以产生最终结果。该模型放弃了不必要的中间步骤,进行端到端的训练和优化。
3. 算法框架:首先,将图像送到FCN网络结构中并且生成单通道像素级的文本分数特征图和多通道几何图形特征图。文本区域采用了两种几何形状:旋转框(RBOX)和水平(QUAD),并为每个几何形状设计了不同的损失函数;然后,将阈值应用于每个预测区域,其中评分超过预定阈值的几何形状被认为是有效的,并且保存以用于随后的非极大抑制。NMS之后的结果被认为是pipeline的最终结果。
-
SAST算法:百度自研文字检测算法,一种基于分割的one-shot任意形状文本检测器
-
实际上就是EAST算法的扩展,一阶段,输出为multitask,各个分支相互校正。
-
它利用基于全卷积网络(FCN)的上下文多任务学习框架来学习文本区域的各种几何特征,从而构造文本区域的多边形表示。考虑到文本的连续性特征,通过引入Context Attention Block 捕捉像素的长范围相关性,一次来获得更加可靠的分割结果。在后处理过程中,提出一个点到边的对齐方法,来将像素聚类称为文本实力,这样就通过一次采样图片,把高级别的特征和低级别的特征结合在一起。此外,利用所提出的几何性质可以更有效地提取任意形状文本的多边形表示。
-
def order_points_clockwise(self, pts):
xSorted = pts[np.argsort(pts[:, 0]), :]
leftMost = xSorted[:2, :]
rightMost = xSorted[2:, :]
leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
(tl, bl) = leftMost
rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
(tr, br) = rightMost
rect = np.array([tl, tr, br, bl], dtype="float32")
return rect
作用:顺时针对点进行排序,首先根据横坐标找到最左边和最右边的的点,然后针对最左边的数据根据他们的纵坐标进行排序,我们可以找到最上面和最下面的点
代码位置:deploy->hubserving->ocr_det->module.py
def predict(self, images=[], paths=[]):
if images != [] and isinstance(images, list) and paths == []:
predicted_data = images
elif images == [] and isinstance(paths, list) and paths != []:
predicted_data = self.read_images(paths)
else:
raise TypeError("The input data is inconsistent with expectations.")
assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
all_results = []
for img in predicted_data:
if img is None:
logger.info("error in loading image")
all_results.append([])
continue
dt_boxes, elapse = self.text_detector(img)
logger.info("Predict time : {}".format(elapse))
rec_res_final = []
for dno in range(len(dt_boxes)):
rec_res_final.append({
'text_region': dt_boxes[dno].astype(np.int).tolist()
})
all_results.append(rec_res_final)
return all_results
作用:获取预测图像中的文本框。
输入参数:def predict(self, images=[], paths=[])
- self:实例对象本身
- images:一个列表,且每个图像的形状为[H,W,C]
- paths: 图像的路径,文本检测框的结果和图像的保存路径。
for dno in range(len(dt_boxes)):
rec_res_final.append({
'text_region': dt_boxes[dno].astype(np.int).tolist()
})