目录
引言
- LGPMA: Complicated Table Structure Recognition with Local and Global Pyramid Mask Alignment 是海康威视在ICDAR2021 Table Recognition赛道获得冠军的方案。该方案对应的源码已经开源。这波操作算是很良心了。
- PDF | Code
- 该论文采用方法可以分为深度学习下从目标检测角度做的表格识别方法一栏。具体可以参见:OCR之表格结构识别综述 。
- 代码是集成到海康的DAVAR OCR工具箱中,其中LGPMA中基础backbone的实现基于MMDet和MMCV模块,当然,采用的深度学习框架自然是PyTorch。
- 本篇文章将从论文与代码一一对应解析的方式来撰写,这样便于找到论文重点地方以及用代码如何实现的,更快地学到其中要点。
- 如果读者可以阅读英文的话,建议先去直接阅读英文论文,会更直接看到整个面貌。
2022-06-08 update
- 整理开源基于LGPMA官方代码得到的推理代码仓库LGPMA_Infer 。
- 同时在该仓库下,尝试了转onnx模型,虽说成功转换,但是由于转换过程和推理过程中耗费内存太大(128G内存都没够),遂放弃。
LGPMA整体结构
- 由论文以及框图,可以将整体结构分为5部分,训练阶段有3部分(Aligned Bounding Box Detectection、LPMA、GPMA),推理阶段有2部分(Aligned Bounding Box Refine、Table Structure Recovery)。 下面我将一一作解读。
- 整体结构对应的代码位于
demo/table_recognition/lgpma/configs/lgpma_base.py
,可直接点击查看。MMDet系列所有的配置文件均是通过py文件中字典方式指定。不过,个人认为没有yaml格式文件更直观一些。
训练阶段
Aligned Boudning Box Detection(对齐的包围框检测)
- 该分支是直接用来检测包围框对齐的非空cell的。举个例子来说,下图中从每一列来看,每个单元格的红框长度和高度都基本一致且每个单元格中有值,这就是aligned bounding box for non-empty cells的意思。
- 为什么要用aligned bouding box作训练呢?
答:假如我们可以获得aligned cell,同时整个表格没有空的单元格,这时,根据每个单元格在行和列方向上的坐标值,我们很容易就可以得到cell之间的关系,也就很容易还原整个表格。 - aligned bounding box的标签如何获得呢?
答:根据已有表格中文本区域标注的数据,可以通过取每一行中高最大的文本框值作为该行cell的高,取每一列最宽的文本框值作为该列cell的宽。这样就可以获得aligned bouding box的标注数据。 - 因为在训练时,cell框的标签就是对齐的,所以最终推理所得同列的框都是基本一致的。
- 在训练过中,由于某些表格数据中存在空的cell,这就使得该分支不容易很好地学到有效的信息,这也就引出了下面的两个分支LPMA和GPMA。
- 该部分代码主要是采用mmdetection的接口实现,这里给出配置文件源码位置,具体每个接口的使用,可以参见mmdetection文档。
roi_head=dict( type='LGPMARoIHead', bbox_roi_extractor=dict( type='SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), out_channels=256, featmap_strides=[4, 8, 16, 32]), bbox_head=dict( type='Shared2FCBBoxHead', in_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=2, # 这里应该是是否有文本的二分类 bbox_coder=dict( type='DeltaXYWHBBoxCoder', target_means=[0., 0., 0., 0.], target_stds=[0.1, 0.1, 0.2, 0.2]), reg_class_agnostic=False, loss_cls=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)), mask_roi_extractor=dict( type='SingleRoIExtractor', roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), out_channels=256, featmap_strides=[4, 8, 16, 32]), mask_head=dict( type='LPMAMaskHead', # 这里是LPMAHead配置文件 num_convs=4, in_channels=256, conv_out_channels=256, num_classes=2, loss_mask=dict( type='CrossEntropyLoss', use_mask=True, loss_weight=1.0), loss_lpma=dict( type='L1Loss', loss_weight=1.0))),
Local Pyramid Mask Alignment (LPMA)
- 该分支针对的是每个cell中的text region (文本区域)。
- 分为两部分:
- 第一部分是二值化分割任务,只用来判断当前区域是否为text region;可以从loss看到有对应实现。
- 第二部分是pyramid mask regression,对预测得到的bounding boxes,分别在水平和垂直方向上赋于soft-label,该方法出自Pyramid mask text detector。
- 具体参见下图,简单来说,就是在下图蓝色框中,Text出现的红色框周围像素值赋予更高的权重。
- 之所以这样做的,是因为使用soft-label segmentation也许可以打破proposed bounding box的限制,同时预测到更加精准的aligned bounding boxes。
- 代码位于davarocr/davar_table/models/roi_heads/mask_heads/lpma_mask_head.py
- 其中获得local pyramid mask的代码位于davarocr/davar_table/core/mask/lp_mask_target.py,该部分代码可直接对应论文中的公式(1):
t h ( ω , h ) = { w / x 1 w ≤ x m i d W − w W − x 2 w > x m i d t v ( ω , h ) = { h / x 1 h ≤ y m i d H − h H − y 2 h > h m i d t_{h}^{(\omega, h)} = \left\{ \begin{aligned} & w/x_{1} & w\leq x_{mid} \\ &\frac{W - w}{W - x_{2}} & w > x_{mid} \end{aligned} \right. \quad t_{v}^{(\omega, h)} = \left\{ \begin{aligned} & h/x_{1} & h \leq y_{mid} \\ &\frac{H - h}{H - y_{2}} & h > h_{mid} \end{aligned} \right. th(ω,h)=⎩⎪⎨⎪⎧w/x1W−x2W−ww≤xmidw>xmidtv(ω,h)=⎩⎪⎨⎪⎧h/x1H−y2H−hh≤ymidh>hmid - 摘抄部分代码如下
# davarocr/davar_table/core/mask/lp_mask_target.py#L57 # Calculate the pyramid mask in horizontal direction col_np = np.arange(x_min, x_max + 1).reshape(1, -1) col_np_1 = (col_np[:, :middle_x - x_min] - left_col) / (middle_x - left_col) col_np_2 = (right_col - col_np[:, middle_x - x_min:]) / (right_col - middle_x) col_np = np.concatenate((col_np_1, col_np_2), axis=1) mask_s1[ind, y_min:y_max + 1, x_min:x_max + 1] = col_np # Calculate the pyramid mask in vertical direction row_np = np.arange(y_min, y_max + 1).reshape(-1, 1) row_np_1 = (row_np[:middle_y - y_min, :] - left_row) / (middle_y - left_row) row_np_2 = (right_row - row_np[middle_y - y_min:, :]) / (right_row - middle_y) row_np = np.concatenate((row_np_1, row_np_2), axis=0) mask_s2[ind, y_min:y_max + 1, x_min:x_max + 1] = row_np
Global Pyramid Mask Alignment (GPMA)
- 因为LPMA分支感受野十分有限,因此考虑加入全局特征。值得注意的是,只有该分支学习了非空的cell的信息,因为空的cell不存在text region,无法在LPMA中学习。
- 由于每个单元格中宽高比变化比较大,在回归学习任务中,这往往会带来很大的不平衡问题,所以采用和LPMA相同的策略:分为两个同时进行的任务,全局分割任务和全局的pyramid mask regression任务。
- 该部分代码位于davarocr/davar_table/models/seg_heads/gpma_mask_head.py。其中获得pyramid masks部分,是先补齐空的单元格大小,同时对非空的单元格采用和LPMA同样的做法。摘抄部分代码如下:
# davarocr/davar_table/models/seg_heads/gpma_mask_head.py#L228 mask_pred_ = mask_pred[i, 0, :, :] mask_pred_resize = mmcv.imresize(mask_pred_, (w_pad, h_pad)) mask_pred_resize = mmcv.imresize(mask_pred_resize[:h_img, :w_img], (w_ori, h_ori)) mask_pred_resize = (mask_pred_resize > 0.5) cell_region_mask.append(mask_pred_resize) # 先补齐,再获得水平和竖直的mask reg_pred1_ = reg_pred[i, 0, :, :] reg_pred2_ = reg_pred[i, 1, :, :] reg_pred1_resize = mmcv.imresize(reg_pred1_, (w_pad, h_pad)) reg_pred1_resize = mmcv.imresize(reg_pred1_resize[:h_img, :w_img], (w_ori, h_ori)) reg_pred2_resize = mmcv.imresize(reg_pred2_, (w_pad, h_pad)) reg_pred2_resize = mmcv.imresize(reg_pred2_resize[:h_img, :w_img], (w_ori, h_ori)) gp_mask_hor.append(reg_pred1_resize) gp_mask_ver.append(reg_pred2_resize)
推理阶段
该阶段主要分为两个阶段,首先获得refined后的aligned bounding boxes,然后由structure recovery pipeline来还原表格。
Inference: Aligned Bounding Box Refine (微调预测所得检测框)
- 该部分主要是由于训练阶段采用的pyramid label,整个部分可以参考论文:Pyramid mask text detector,在该论文中有详细说明。
- 实话说,该部分暂时还没完全看懂,感兴趣的小伙伴可以直接去看论文的3.6小节。
- 尝试按照论文中的公式(7)与代码对应实现做了比对,发现代码实现与公式(7)并不对应。
- 公式(7):
x r e f i n e = − 1 y 2 − y 1 + 1 ∑ y i = y 1 y 2 b y i + c a x_{refine} = - \frac{1}{y_{2} - y_{1} +1}\sum_{y_{i} = y_{1}}^{y_{2}}\frac{by_{i} + c}{a} xrefine=−y2−y1+11yi=y1∑y2abyi+c - 对应该块实现的代码:
def refine_x(xmin, xmax, ymin, ymax): """Refining left boundary or right boundary. Args: xmin(int): left boundary of original aligned bboxes. xmax(int): right boundary of original aligned bboxes. ymin(int): top boundary of original aligned bboxes. ymax(int): lower boundary of original aligned bboxes. Returns: int: the refined boundary. """ a_sum = get_matrix(xmin, xmax, ymin, ymax) z_sum = get_vector(xmin, xmax, ymin, ymax, soft_mask[0]) try: (a, b, c) = np.dot(a_sum.I, z_sum) except: return -1 y_mean = (ymax + ymin) / 2 x_refine = int((-1 * c / a - y_mean * b / a) + 0.5) return x_refine
- 将最后一行代码转换为公式,即为:
x r e f i n e = 1 2 − b y ˉ + c a ( 1 ) x_{refine} = \frac{1}{2} - \frac{b\bar{y} + c}{a} \quad (1) xrefine=21−abyˉ+c(1)
其中, y ˉ = y m a x + y m i n 2 \bar{y} = \frac{y_{max} + y_{min}}{2} yˉ=2ymax+ymin - 将论文中公式(7)展开得到如下:
x r e f i n e = − 1 y 2 − y 1 + 1 b ( y 1 + y 2 ) + 2 c a = − 1 y 2 − y 1 + 1 2 ( b y ˉ + c ) a ( 2 ) x_{refine} = - \frac{1}{y_{2} - y_{1} + 1}\frac{b(y_1 + y_2) + 2c}{a} \\ = - \frac{1}{y_{2} - y_{1} + 1} \frac{2(b\bar{y} + c)}{a} \quad (2) xrefine=−y2−y1+11ab(y1+y2)+2c=−y2−y1+11a2(byˉ+c)(2) - 文中公式(1)和公式(2)并不相等!!
- 公式(7):
Inference: Table Structure Recovery (表格结构还原)
-
该部分主要分为三步:单元格匹配、空单元格搜寻和空单元格融合。
-
单元格匹配。 思路很简单,如果两个对齐后的cell框,在x/y轴上有着很大的重叠部分,我们就有理由认为它们是在同一列/行上。对应代码位于:davarocr/davar_table/core/post_processing/post_lgpma.py#L355-L364
# Calculating cell adjacency matrix according to bboxes of non-empty aligned cells bboxes_np = np.array(bboxes) adjr, adjc = bbox2adj(bboxes_np) # Predicting start and end row / column of each cell according to the cell adjacency matrix colspan = adj_to_cell(adjc, bboxes_np, 'col') rowspan = adj_to_cell(adjr, bboxes_np, 'row') cells = [[row.min(), col.min(), row.max(), col.max()] for col, row in zip(colspan, rowspan)] cells = [list(map(int, cell)) for cell in cells] cells_np = np.array(cells)
-
空单元格搜寻。 将aligned bounding boxes作为节点,它们之间的关系作为边。所有在同一行/列的节点构成了一个子图,采用Maximum Clique Search算法。我们要知道什么时候表格中会出现空的cell?答案是出现单元格合并情况时。
- 具体以行搜索过程为例讲解原理。具体列和行子图示意图,下图来自论文:Rethinking table recognition using graph neural networks,图中中间部分即是行子图,可以看到的是TriStar节点出现在了两个子图中。
- 当一个合并后的单元格跨多个行时,其相应的节点肯定会在多个子图中出现,就像上图中的TriStar单元格。
- 将所有的行子图按照y轴排序,很容易定位到每个节点的行索引,而那些出现在多个子图中的节点,也会被标记到多个行索引上。
- 由此,可以认定出现在多个行索引中的节点即是空的cell,或者说是空的cell的一部分(合并单元格前的独立单元格)。
- 对应源码位于:davarocr/davar_table/core/post_processing/post_lgpma.py#L25
from networkx import Graph, find_cliques def adj_to_cell(adj, bboxes, mod): """Calculating start and end row / column of each cell according to row / column adjacent relationships Args: adj(np.array): (n x n). row / column adjacent relationships of non-empty aligned cells bboxes(np.array): (n x 4). bboxes of non-empty aligned cells mod(str): 'row' or 'col' Returns: list(np.array): start and end row of each cell if mod is 'row' / start and end col of each cell if mod is 'col' """ assert mod in ('row', 'col') # generate graph of each non-empty aligned cells nodenum = adj.shape[0] edge_temp = np.where(adj != 0) edge = list(zip(edge_temp[0], edge_temp[1])) # 采用的networkx中图和find_cliques函数接口 table_graph = Graph() table_graph.add_nodes_from(list(range(nodenum))) table_graph.add_edges_from(edge) # Find maximal clique in the graph clique_list = list(find_cliques(table_graph)) # 省略部分代码
- 具体以行搜索过程为例讲解原理。具体列和行子图示意图,下图来自论文:Rethinking table recognition using graph neural networks,图中中间部分即是行子图,可以看到的是TriStar节点出现在了两个子图中。
-
空单元格合并。 经过上述几个步骤,我们可以获得空的单元格位置,注意这里获得的空的单元格仅仅只是最小的单元格,并非是合并后的。
- 为了更加可靠的方式合并这些单元格,首先将这些空的单元格的大小设定为同行或同列单元格中最大的宽和高。
- 随后,计算两个相邻的单元格内部之间,被预测为1的像素比例,如下图中红框所示。如果比例占用大于设定的阈值,则将这两个cell合并为一个。
-
源码:davarocr/davar_table/core/post_processing/post_lgpma.py#L366-L377
# Searching empty cells and recording them through arearec arearec = np.zeros([cells_np[:, 2].max() + 1, cells_np[:, 3].max() + 1]) for cellid, rec in enumerate(cells_np): srow, scol, erow, ecol = rec[0], rec[1], rec[2], rec[3] arearec[srow:erow + 1, scol:ecol + 1] = cellid + 1 empty_index = -1 # deal with empty cell for row in range(arearec.shape[0]): for col in range(arearec.shape[1]): if arearec[row, col] == 0: cells.append([row, col, row, col]) arearec[row, col] = empty_index empty_index -= 1
-
推理最终的生成结果是一个字典形式,举例如下:
{ 'html': '<html><body><table><thead><tr><td colspan="3"></td><td></td><td></td><td></td><td rowspan="2"></td></tr><tr><td></td><td></td><td></td><td></td><td></td><td></td></tr></thead><tbody><tr><td></td><td></td><td></td><td></td><td></td><td></td><td></td></tr><tr><td></td><td></td></tbody></table></body></html>', 'content_ann': { 'bboxes': [[10, 9, 216, 29], [], [], [], [642, 6, 817, 80], [8, 40, 120, 80], ], 'labels': [[0], [0], [1], [0], [1], [0]], 'texts': ['', '', '', '', '', ''] } }
总结
- LGPMA算法整体思路清晰,行文逻辑清晰,值得学习。
- 由于源码是基于mmdetection修改而来的,所以整个复现环境有些繁琐,不过该仓库的作者维护还是十分及时的。
- 本篇文章涉及东西较多,难免挂一漏万,如果哪里写的不当,还请指出。