PaddleOCR 文字检测部分源码学习(10)-后处理(2)

2021SC@SDUSC
代码位置:ppocr\postprocess\east_postprocess.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from .locality_aware_nms import nms_locality
import cv2
import paddle

import os
import sys


class EASTPostProcess(object):
    """
    The post process for EAST.
    """
    def __init__(self,
                 score_thresh=0.8,
                 cover_thresh=0.1,
                 nms_thresh=0.2,
                 **kwargs):

        self.score_thresh = score_thresh
        self.cover_thresh = cover_thresh
        self.nms_thresh = nms_thresh
        
        # c++ la-nms is faster, but only support python 3.5
        self.is_python35 = False
        if sys.version_info.major == 3 and sys.version_info.minor == 5:
            self.is_python35 = True

    def restore_rectangle_quad(self, origin, geometry):
        """
        Restore rectangle from quadrangle.
        """
        # quad
        origin_concat = np.concatenate(
            (origin, origin, origin, origin), axis=1)  # (n, 8)
        pred_quads = origin_concat - geometry
        pred_quads = pred_quads.reshape((-1, 4, 2))  # (n, 4, 2)
        return pred_quads

    def detect(self,
               score_map,
               geo_map,
               score_thresh=0.8,
               cover_thresh=0.1,
               nms_thresh=0.2):
        """
        restore text boxes from score map and geo map
        """
        score_map = score_map[0]
        geo_map = np.swapaxes(geo_map, 1, 0)
        geo_map = np.swapaxes(geo_map, 1, 2)
        # filter the score map
        xy_text = np.argwhere(score_map > score_thresh)
        if len(xy_text) == 0:
            return []
        # sort the text boxes via the y axis
        xy_text = xy_text[np.argsort(xy_text[:, 0])]
        #restore quad proposals
        text_box_restored = self.restore_rectangle_quad(
            xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :])
        boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
        boxes[:, :8] = text_box_restored.reshape((-1, 8))
        boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
        if self.is_python35:
            import lanms
            boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
        else:
            boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
        if boxes.shape[0] == 0:
            return []
        # Here we filter some low score boxes by the average score map, 
        #   this is different from the orginal paper.
        for i, box in enumerate(boxes):
            mask = np.zeros_like(score_map, dtype=np.uint8)
            cv2.fillPoly(mask, box[:8].reshape(
                (-1, 4, 2)).astype(np.int32) // 4, 1)
            boxes[i, 8] = cv2.mean(score_map, mask)[0]
        boxes = boxes[boxes[:, 8] > cover_thresh]
        return boxes

    def sort_poly(self, p):
        """
        Sort polygons.
        """
        min_axis = np.argmin(np.sum(p, axis=1))
        p = p[[min_axis, (min_axis + 1) % 4,\
            (min_axis + 2) % 4, (min_axis + 3) % 4]]
        if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
            return p
        else:
            return p[[0, 3, 2, 1]]

    def __call__(self, outs_dict, shape_list):
        score_list = outs_dict['f_score']
        geo_list = outs_dict['f_geo']
        if isinstance(score_list, paddle.Tensor):
            score_list = score_list.numpy()
            geo_list = geo_list.numpy()
        img_num = len(shape_list)
        dt_boxes_list = []
        for ino in range(img_num):
            score = score_list[ino]
            geo = geo_list[ino]
            boxes = self.detect(
                score_map=score,
                geo_map=geo,
                score_thresh=self.score_thresh,
                cover_thresh=self.cover_thresh,
                nms_thresh=self.nms_thresh)
            boxes_norm = []
            if len(boxes) > 0:
                h, w = score.shape[1:]
                src_h, src_w, ratio_h, ratio_w = shape_list[ino]
                boxes = boxes[:, :8].reshape((-1, 4, 2))
                boxes[:, :, 0] /= ratio_w
                boxes[:, :, 1] /= ratio_h
                for i_box, box in enumerate(boxes):
                    box = self.sort_poly(box.astype(np.int32))
                    if np.linalg.norm(box[0] - box[1]) < 5 \
                        or np.linalg.norm(box[3] - box[0]) < 5:
                        continue
                    boxes_norm.append(box)
            dt_boxes_list.append({'points': np.array(boxes_norm)})
        return dt_boxes_list

在这里插入图片描述
对于一个四边形 Q = {pi|i ∈ {1, 2, 3, 4}},pi = {xi, yi},表示四边形的顺时针方向的四个顶点,从上图可以看到,我们最终要的是(a)中绿色框的部分,即相对于黄色的虚线框缩小一定程度可以得到绿色的框。为了缩小Q,先对于每个顶点pi计算一个长度ri:
在这里插入图片描述

可以看到,上述公式要表达的意思就是该顶点连接的两条边取最短的那一条。缩小Q的过程大致如下:先缩小较长的两条边,然后缩小较短的两条边。对于每两个相对的边,决定谁是更长的一对的方式是比较他们长度的均值。对于每条边 (pi, p(i mod 4)+1),我们通过将其两个端点沿着边向内分别移动0.3ri和0.3r(i mod 4)+1。

geometry map的生成方式如下:首先用一个旋转的矩形,用最小的面积可以覆盖住该文本区域。然后对于每个有正例得分的像素,计算他到四个文本框边界的距离,并且使其作为RBOX的ground truth的4个通道,而对于QUAD ground truth,在其8通道的geometry map上每个正得分像素的值是它的四边形的四个顶点到它的坐标的偏移。

代码实现上有很多细节部分,例如参照了https://github.com/SakuraRiven/EAST这一份工程的代码:
在生成score_map的时候,是基于原图的四分之一大小(因为输出的特征图的大小就是四分之一原图大小),然后在缩小时先缩短较长的两边,假设该边edge1的两个点时v1,v2,对于v1和v2两点各自都连着两条边v1(r1,r2),v2(r2,r3),取和两个点相连的短边(可能是r1和r3,也可能都是r2,这里假设是r1和r3),然后分别将v1和v2两点沿着该边的方向向内移动0.3r1和0.3r3的长度。
在生成geo_map的时候,先每隔1度遍历所有角度(-90,90),找到最小外接矩形和矩形的角度,然后旋转文本框旋转到theta=0的水平状态,这样是为了方便计算d1,d2,d3,d4。
另外还有一个ignore_map,在ICDAR_2015数据集中,有些标签是“###”,这个表示是无法看清的一些文本,该代码中会把这些区域取出来,然后忽略掉这些区域,个人觉得这个也是为了防止误检。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值