表格的扭曲矫正
在通过 OCR 对表格中的图像进行识别的时候发现,表格图像的定位是很重要的。虽然通过 Hough 变换找直线并设置较低的阈值,可以较为鲁棒地找出表格的横线,但是表格如果存在一定程度的扭曲偏移,则检测切图的效果会很差。看了些论文后发现,基于函数的,或者基于相机视角的矫正确实效果好,但是就目前基于表格的应用来说,还是越简单越好。
以上是矫正的效果图。思路很简单:
- 通过十字模板去匹配表格交点.
- 通过hough找出的直线的相对位置过滤掉不符合的交点.
- 按照先列后行的顺序,对交点进行排序
- 通过交点的均值,找出标准网格中,交点的排布位置
- 遍历网格,对网格单元进行从原图像到标准图像的映射转换
原图片
代码
class Choose:
    """Locate the grid intersections of a table image and flatten a warped table.

    Pipeline (see ``dewarp``):

    1. match a cross-shaped template against a binarised image to find
       candidate line intersections (``detectCross``);
    2. keep only candidates close to the reference lines returned by
       ``self.findDetailBox`` and arrange them column by column
       (``getNetPoints``);
    3. build an axis-aligned "standard" grid from the per-column / per-row
       means of those points;
    4. warp every grid cell of the input onto the matching standard cell.

    The methods rely on the module-level names ``np`` (NumPy) and ``cv2``
    (OpenCV).  ``self.findDetailBox`` and ``showImage`` are used but are not
    defined in this part of the file.
    """

    def detectCross(self, imgt, hsize, wsize, thr=0.65):
        """Find cross-shaped intersections in ``imgt`` by template matching.

        A ``hsize`` x ``wsize`` template is built containing one horizontal
        and one vertical bar, each 3 px thick, crossing at the template
        centre.  It is matched with ``cv2.TM_CCORR_NORMED``.

        :param imgt: image to search; it is converted to ``uint8`` first.
        :param hsize: template height in pixels.
        :param wsize: template width in pixels.
        :param thr: minimum normalised correlation score for a hit.
        :return: list of ``[y, x]`` positions of the cross centres.
        """
        midh, midw = hsize // 2, wsize // 2

        template = np.zeros([hsize, wsize], dtype=np.uint8)
        template[midh - 1:midh + 2, :] = 255
        template[:, midw - 1:midw + 2] = 255

        match_res = cv2.matchTemplate(np.array(imgt, dtype=np.uint8), template,
                                      cv2.TM_CCORR_NORMED)

        # matchTemplate reports the top-left corner of every placement; shift
        # by half the template size so the result is the cross centre.
        ys, xs = np.where(match_res > thr)
        return [[y + midh, x + midw] for y, x in zip(ys, xs)]

    def getNetPoints(self, img):
        """Return the table intersections of ``img`` arranged as a grid.

        :param img: colour image; it is converted with ``COLOR_RGB2GRAY``,
            so a 3-channel array is presumably expected -- TODO confirm.
        :return: ``np.array`` built from one list of ``[y, x]`` points per
            column window, each sorted by ``y``.  It is a regular
            ``(columns, rows, 2)`` array only when every column ends up with
            the same number of points.
        :raises ValueError: if no intersection lies near a reference line.
        """
        img = np.array(img)
        imgt = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

        # Morphological opening with a cross kernel, then a two-step
        # threshold: Otsu only supplies the level, the actual binarisation
        # uses 1.2x that level (inverted).
        kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (10, 10))
        imgt = cv2.erode(imgt, kernel, iterations=1)
        imgt = cv2.dilate(imgt, kernel, iterations=1)
        ret, _ = cv2.threshold(imgt, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
        ret, imgt = cv2.threshold(imgt, ret * 1.2, 255, cv2.THRESH_BINARY_INV)

        # NOTE: the 40x45 template and the 0.65 threshold appear to have been
        # picked by hand (the original kept a commented-out sweep over square
        # template sizes 30..90 for visual inspection).
        points = self.detectCross(imgt, 40, 45, 0.65)

        # Reference lines of the table, derived from detected straight lines.
        needh, needw = self.findDetailBox(img)

        # Keep only the points whose x lies within 10 px of a reference x.
        # Both bounds are clamped at 0 so that a reference closer than 10 px
        # to the left edge does not turn into a wrapping negative slice.
        near_line = np.zeros(imgt.shape[1], dtype=bool)
        for w in np.unique(np.array(needw)):
            near_line[max(w - 10, 0):max(w + 10, 0)] = True
        needs_points = np.array([[y, x] for y, x in points if near_line[x]])
        if len(needs_points) == 0:
            raise ValueError("no table intersection found near the reference lines")

        # One x-window per column boundary: the left edge of the first
        # interval, then every seam between two neighbouring intervals.
        mg = 8
        col_scale = [[needw[0][0] - mg, needw[0][0] + mg]]
        col_scale += [[front[1] - mg, cur[0] + mg]
                      for front, cur in zip(needw[:-1], needw[1:])]

        # Expected vertical spacing of the rows, used to resample columns.
        minv, maxv = needs_points[:, 0].min(), needs_points[:, 0].max()
        scale = (maxv - minv) // len(needh)

        gridx = []
        for s, e in col_scale:
            in_window = (needs_points[:, 1] > s) & (needs_points[:, 1] < e)
            arr = needs_points[in_window]
            if len(arr) == 0:
                continue
            arr = arr[np.argsort(arr[:, 0])]
            # NOTE(review): a column's point count is compared with the number
            # of column intervals (len(needw)); the number of row lines looks
            # like the intended reference -- confirm against findDetailBox.
            if len(arr) != len(needw):
                # Resample on evenly spaced rows: take the nearest detected
                # point within 30 px, otherwise synthesise one at the
                # column's mean x.
                tarr = []
                for vh in range(minv, maxv + 1, scale):
                    dx = np.abs(arr[:, 0] - vh)
                    if dx.min() >= 30:
                        tmp = [vh, int(arr[:, 1].mean())]
                    else:
                        tmp = arr[dx.argmin(), :]
                    tarr.append(tmp)
                arr = tarr
            gridx.append(arr)
        return np.array(gridx)

    def drawPoints(self, img, points):
        """Debug helper: draw ``points`` (``[y, x]`` pairs) on ``img`` and show it.

        ``img`` is modified in place by ``cv2.circle``.
        """
        print("size of points: ", len(points))
        for y, x in points:
            cv2.circle(img, (x, y), 3, (0, 0, 255), 3)
        showImage(img)

    @staticmethod
    def _standard_grid(points, ah, aw):
        """Build the axis-aligned target grid, padded out to the image border.

        Every column gets the (truncated) mean x of its points and every row
        the mean y; one extra ring of nodes is then placed on the border of
        an ``ah`` x ``aw`` image.
        """
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # supported spelling of the same dtype.
        standard = np.zeros(points.shape, dtype=int)
        standard[:, :, 1] = points[:, :, 1].mean(axis=1).astype(int)[:, np.newaxis]
        standard[:, :, 0] = points[:, :, 0].mean(axis=0).astype(int)[np.newaxis, :]

        standard = np.pad(standard, ((1, 1), (1, 1), (0, 0)), 'edge')
        standard[:, 0, 0] = 0
        standard[0, :, 1] = 0
        standard[-1, :, 1] = aw
        standard[:, -1, 0] = ah
        return standard

    @staticmethod
    def _bordered_points(points, ah, aw):
        """Pad the detected grid to the image border by linear extrapolation.

        Returns a new array; ``points`` itself is left untouched.
        """
        points = np.pad(points, ((1, 1), (1, 1), (0, 0)), 'edge')
        points[:, 0, 0] = 0
        points[0, :, 1] = 0
        points[-1, :, 1] = aw
        points[:, -1, 0] = ah

        cols, rows = points.shape[:2]
        # Top / bottom border: extend each column line through its two
        # outermost detected points until it reaches y = 0 and y = ah.
        for i in range(1, cols - 1):
            pa, pb = points[i, 1], points[i, 2]
            points[i, 0, 1] = int(pa[1] - pa[0] * (pa[1] - pb[1]) / (pa[0] - pb[0]))
            pa, pb = points[i, -2], points[i, -3]
            points[i, -1, 1] = int(pa[1] + (ah - pa[0]) * (pa[1] - pb[1]) / (pa[0] - pb[0]))
        # Left / right border: the same for each row line, to x = 0 and x = aw.
        for j in range(1, rows - 1):
            pa, pb = points[1, j], points[2, j]
            points[0, j, 0] = int(pa[0] - pa[1] * (pa[0] - pb[0]) / (pa[1] - pb[1]))
            pa, pb = points[-2, j], points[-3, j]
            points[-1, j, 0] = int(pa[0] + (aw - pa[1]) * (pa[0] - pb[0]) / (pa[1] - pb[1]))
        return points

    def dewarp(self, img):
        """Flatten a warped table image cell by cell.

        Best effort: if any step fails, the failure is logged and ``img`` is
        returned unchanged.

        :param img: image array (must provide ``.shape``).
        :return: a new ``uint8`` image of the same shape, or ``img`` itself
            when dewarping was not possible.
        """
        try:
            ah, aw = img.shape[:2]
            detected = self.getNetPoints(img)
            standard = self._standard_grid(detected, ah, aw)
            points = self._bordered_points(detected, ah, aw)

            cols, rows = points.shape[:2]
            imgnew = np.zeros(img.shape, dtype=np.uint8)
            for i in range(1, cols):
                for j in range(1, rows):
                    # Plain ints keep the OpenCV size argument and the slice
                    # bounds independent of NumPy scalar types.
                    height, width = (int(v) for v in standard[i, j] - standard[i - 1, j - 1])
                    locy, locx = (int(v) for v in standard[i - 1, j - 1])
                    box_std = np.array(((0, 0), (width, 0), (0, height), (width, height)),
                                       dtype=np.float32)
                    # Corners in the same order as box_std; points are stored
                    # as [y, x], OpenCV wants (x, y), hence the [::-1].
                    pa, pb, pc, pd = (points[i - 1, j - 1], points[i, j - 1],
                                      points[i - 1, j], points[i, j])
                    box_origin = np.array((pa[::-1], pb[::-1], pc[::-1], pd[::-1]),
                                          dtype=np.float32)
                    M = cv2.getPerspectiveTransform(box_origin, box_std)
                    tmp = cv2.warpPerspective(img, M, (width, height))
                    imgnew[locy:locy + height, locx:locx + width] = tmp
            return imgnew
        except Exception:
            # Deliberate best-effort fallback, but no longer silent and no
            # longer swallowing KeyboardInterrupt / SystemExit.
            import logging
            logging.getLogger(__name__).warning(
                "dewarp failed; returning the input image unchanged", exc_info=True)
            return img
# # Debug plotting: overlay the detected and the standard grid points
# for i in range(cols):
# for j in range(rows):
# y, x = points[i, j]
# cv2.circle(img, (x, y), 5, [0, 0, 255], 3)
# # cv2.putText(img, "%.2d" % (i * rows + j), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
#
# y, x = standard[i, j]
# cv2.circle(imgnew, (x, y), 5, [255, 0, 0], 3)
# cv2.circle(img, (x, y), 5, [0, 255, 0], 3)
# # cv2.putText(imgnew, "%.2d" % (i * rows + j), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
#
# imgnew = np.vstack([img, imgnew])
# showImage(imgnew)