ocr表格扭曲矫正

表格的扭曲矫正

在通过ocr对表格中的图像进行识别的时候发现,表格图像的定位是很重要的.虽然通过hough找直线,设低阈值,可以较高鲁棒性的找出表格的横线,但是表格如果存在一定程度的扭曲偏移,则检测切图的效果很差劲.看些论文发现,基于函数的,或者相机视角的矫正确实效果好.但是就目前的基于表格的应用来说还是越简单越好.
在这里插入图片描述
以上是矫正的效果图.思路很简单,

  • 通过十字模板去匹配表格交点.
  • 通过hough找出的直线的相对位置过滤掉不符合的交点.
  • 通过按列按行,对与交点进行排序
  • 通过交点的均值,找出标准网格中,交点的排布位置
  • 遍历网格,对网格单元进行从原图像到标准图像的映射转换

原图片
在这里插入图片描述
代码

class Choose:
    # 根据十字获取,表格交点坐标点
    def detectCross(self, imgt, hsize, wsize, thr=0.65):
        # 十字中心
        midh, midw = hsize // 2, wsize // 2
        # 十字模板
        template = np.zeros([hsize, wsize])
        template[midh - 1:midh + 2, :] = 255
        template[:, midw - 1:midw + 2] = 255
        match_res = cv2.matchTemplate(np.array(imgt, dtype=np.uint8), np.array(template, dtype=np.uint8),
                                      cv2.TM_CCORR_NORMED)
        # 根据阈值筛选顶点
        index = np.where(match_res > thr)
        points = []
        for y, x in zip(index[0], index[1]):
            points.append([y + midh, x + midw])
        return points

    # 获取表格网格点坐标点
    def getNetPoints(self, img):
        ah, aw = img.shape[:2]
        img = np.array(img)
        imgt = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (10, 10))
        imgt = cv2.erode(imgt, kernel, iterations=1)
        imgt = cv2.dilate(imgt, kernel, iterations=1)
        ret, _ = cv2.threshold(imgt, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
        ret, imgt = cv2.threshold(imgt, ret * 1.2, 255, cv2.THRESH_BINARY_INV)
        # showImage(imgt)
        # # 根据阈值大小和十字图像模板的长宽,查找对应点,寻找合适参数
        # for size in range(30, 95, 5):
        #     points = self.detectCross(imgt, size, size)
        #     imgx = np.array(img)
        #     for y, x in points:
        #         cv2.circle(imgx, (x, y), 5, [0, 0, 255], 1)
        #     showImage(imgx, str(size))
        points = self.detectCross(imgt, 40, 45, 0.65)  # 获取表格点
        # self.drawPoints(img, points)
        mask = np.zeros(imgt.shape[:2])
        for y, x in points:
            mask[y, x] = 1
        needh, needw = self.findDetailBox(img)  # 根据直线获取,表格参考线
        warr = sorted(list(set(np.array(needw).flatten())))
        for w in warr:
            sw, ew = w - 10, w + 10
            mask[:, sw:ew] = mask[:, sw:ew] + 1
        # 保留合适点竖排点
        needs_points = []
        for y, x in points:
            if mask[y, x] >= 2:
                needs_points.append([y, x])
        # self.drawPoints(img, needs_points)
        # for h1, h2 in needh:
        #     for w1, w2 in needw:
        #         cv2.rectangle(img, (w1, h1), (w2, h2), (0, 0, 255), 2)
        # showImage(img)
        # 按照先列后行的顺序,进行网格排序
        needs_points = np.array(needs_points)
        col_scale = []
        mg = 8
        col_scale.append([needw[0][0] - mg, needw[0][0] + mg])
        for i in range(1, len(needw)):
            front, cur = needw[i - 1], needw[i]
            col_scale.append([front[1] - mg, cur[0] + mg])
        gridx = []
        tmppoints = np.array(needs_points)[:, 0].flatten()
        minv, maxv = tmppoints.min(), tmppoints.max()
        scale = (maxv - minv) // len(needh)
        # print(scale, minv, maxv, len(needh))
        # idx = 0
        for s, e in col_scale:  # 过滤不在范围内的,按照列行,排列,与标准线最近的点
            where = np.where((needs_points[:, 1] > s) & (needs_points[:, 1] < e))
            arr = needs_points[where]
            if len(arr) == 0:
                continue
            arr = arr[np.argsort(arr[:, 0])]
            if len(arr) != len(needw):
                tarr = []
                for vh in range(minv, maxv + 1, scale):
                    dx = np.abs(arr[:, 0] - vh)
                    if len(np.where(dx < 30)[0]) == 0:
                        tmp = [vh, int(arr[:, 1].mean())]
                    else:
                        tmp = arr[dx.argmin(), :]
                    tarr.append(tmp)
                arr = tarr
            gridx.append(arr)
        gridx = np.array(gridx)
        # print(gridx.shape)
        # points = gridx.reshape([-1, 2])
        # self.drawPoints(img, points)
        return gridx

    # 画图看网格定位结果
    def drawPoints(self, img, points):
        print("size of points: ", len(points))
        for y, x in points:
            cv2.circle(img, (x, y), 3, (0, 0, 255), 3)
        showImage(img)

    # 表格网格化去扭曲
    def dewarp(self, img):
        try:
            ah, aw = img.shape[:2]
            points = self.getNetPoints(img)
            # print(points.shape)
            # showImage(img)
            cols, rows = points.shape[:2]
            standard = np.zeros(points.shape, dtype=np.int)  # 模板网格点
            # 根据表格点的均值,作为参考点
            for i in range(cols):
                w = int(points[i, :, 1].mean())
                standard[i, :, 1] = w
            for j in range(rows):
                h = int(points[:, j, 0].mean())
                standard[:, j, 0] = h
            # 边界处理
            standard = np.pad(standard, ((1, 1), (1, 1), (0, 0)), 'edge')
            standard[:, 0, 0] = 0
            standard[0, :, 1] = 0
            standard[-1, :, 1] = aw
            standard[:, -1, 0] = ah
            points = np.pad(points, ((1, 1), (1, 1), (0, 0)), 'edge')
            points[:, 0, 0] = 0
            points[0, :, 1] = 0
            points[-1, :, 1] = aw
            points[:, -1, 0] = ah
            # 线性方式处理边界(水平,垂直)
            cols, rows = points.shape[:2]
            for i in range(1, cols - 1):
                pa = points[i, 1, :]
                pb = points[i, 2, :]
                points[i, 0, 1] = int(pa[1] - pa[0] * (pa[1] - pb[1]) / (pa[0] - pb[0]))
                pa = points[i, -2, :]
                pb = points[i, -3, :]
                points[i, -1, 1] = int(pa[1] + (ah - pa[0]) * (pa[1] - pb[1]) / (pa[0] - pb[0]))
            for j in range(1, rows - 1):
                pa = points[1, j, :]
                pb = points[2, j, :]
                points[0, j, 0] = int(pa[0] - pa[1] * (pa[0] - pb[0]) / (pa[1] - pb[1]))
                pa = points[-2, j, :]
                pb = points[-3, j, :]
                points[-1, j, 0] = int(pa[0] + (aw - pa[1]) * (pa[0] - pb[0]) / (pa[1] - pb[1]))
            # 图像网格化矫正
            imgnew = np.zeros(img.shape, dtype=np.uint8)
            for i in range(1, cols):
                for j in range(1, rows):
                    height, weight = standard[i, j] - standard[i - 1, j - 1]
                    locy, locx = standard[i - 1, j - 1]
                    box_std = np.array(((0, 0), (weight, 0), (0, height), (weight, height)), dtype=np.float32)
                    pa, pb, pc, pd = points[i - 1, j - 1], points[i, j - 1], points[i - 1, j], points[i, j]
                    box_origin = np.array((pa[::-1], pb[::-1], pc[::-1], pd[::-1]), dtype=np.float32)
                    # 原始网格坐标,矫正网格坐标
                    M = cv2.getPerspectiveTransform(box_origin, box_std)
                    tmp = cv2.warpPerspective(img, M, (weight, height))
                    # showImageNormal(tmp)
                    imgnew[locy:locy + height, locx:locx + weight] = tmp
            return imgnew
        except:
            return img
        # # 作图
        # for i in range(cols):
        #     for j in range(rows):
        #         y, x = points[i, j]
        #         cv2.circle(img, (x, y), 5, [0, 0, 255], 3)
        #         # cv2.putText(img, "%.2d" % (i * rows + j), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
        #
        #         y, x = standard[i, j]
        #         cv2.circle(imgnew, (x, y), 5, [255, 0, 0], 3)
        #         cv2.circle(img, (x, y), 5, [0, 255, 0], 3)
        #         # cv2.putText(imgnew, "%.2d" % (i * rows + j), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
        #
        # imgnew = np.vstack([img, imgnew])
        # showImage(imgnew)

  • 1
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值