论文阅读: Spatial Dual-Modality Graph Reasoning for Key Information Extraction (关键信息提取算法)

引言
  • 文档图像中的关键信息提取任务(Key Information Extraction, KIE)是实现办公场景自动化的一项重要任务。
  • 如果从OCR场景来看,KIE任务可以作为对OCR提取结果内容的结构化抽取来使用。
  • 这次介绍的论文是Spatial Dual-Modality Graph Reasoning for Key Information Extraction,该论文是在去年3月份挂在arxiv上的,但是具体作者机构暂时没有给出。从论文arxiv主页上来看,其官方Code的链接是MMOCR。大胆猜测,难道作者是商汤的?这就不得而知了。
  • 值得一提的是,PaddleOCR最近也集成了该算法,看来关键信息抽取这一研究方向得到了工业界的一些重视,距离真正落地使用不会太远了。
  • 由于该算法在MMOCR和PaddleOCR中均有实现,考虑到跑通示例程序的便利性,这里以PaddleOCR中实现为例,指出相应的实现源码,用作学习之用。
SDMG-R整体结构

SDMG-R framework

  • 从结构图来看,论文的思路比较清晰。整体结构可分为三个模块:双模态融合模块图推理模块分类模块三个。
双模态融合模块

该模块结合视觉特征和文本特征。其中视觉特征 v i {v_{i}} vi来自U-NetROI-Pooling提取所得,文本特征 t i {t_{i}} ti则是通过Bi-LSTM提取得到的。两个不同模态的特征通过Kronecker乘积操作得以融合。这样得以充分利用图像的一维和二维信息,这也是Dual Modality名称的由来。

Backbone部分

在PaddleOCR中backbone:U-Net,位于ppocr/modeling/backbones/kie_unet_sdmgr.py,主要包括U-NetROI Align两部分。对应的源码如下(省略部分代码),重点在于forward部分:

class Kie_backbone(nn.Layer):
    def __init__(self, in_channels, **kwargs):
        super(Kie_backbone, self).__init__()
        self.out_channels = 16
        self.img_feat = UNet()
        self.maxpool = nn.MaxPool2D(kernel_size=7)

    def bbox2roi(self, bbox_list):
        pass

    def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):
        pass

    def forward(self, inputs):
        # img shape: [1, 3, 1024, 1024]
        img = inputs[0]
        
        relations, texts = inputs[1], inputs[2]
        gt_bboxes, tag, img_size = inputs[3], inputs[5], inputs[-1]
        
        # 前处理
        img, relations, texts, gt_bboxes = self.pre_process(
            img, relations, texts, gt_bboxes, tag, img_size)
        
        # 此时img shape: [1, 3, 512, 512]
        x = self.img_feat(img)
        # output x shape: [1, 16, 512, 512]
        
        boxes, rois_num = self.bbox2roi(gt_bboxes)
        
        feats = paddle.fluid.layers.roi_align(
            x,
            boxes,
            spatial_scale=1.0,
            pooled_height=7,
            pooled_width=7,
            rois_num=rois_num)
        
        # feats shape: [26, 16, 7, 7]
        feats = self.maxpool(feats).squeeze(-1).squeeze(-1)
        
        # output feats shape: [26, 16]
        return [relations, texts, feats]
Head部分

该部分是文本特征的提取,以及backbone部分提取所得图像特征和文本特征的合并部分。这部分主要集中在PaddleOCR代码中head部分下

class SDMGRHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 num_chars=92,
                 visual_dim=16,
                 fusion_dim=1024,
                 node_input=32,
                 node_embed=256,
                 edge_input=5,
                 edge_embed=256,
                 num_gnn=2,
                 num_classes=26,
                 bidirectional=False):
        super().__init__()
		
		# 融合模块
        self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)

        self.node_embed = nn.Embedding(num_chars, node_input, 0)

        hidden = node_embed // 2 if bidirectional else node_embed

		# 单层LSTM模块
        self.rnn = nn.LSTM(
            input_size=node_input, hidden_size=hidden, num_layers=1)

		# 图推理模块
        self.edge_embed = nn.Linear(edge_input, edge_embed)
        self.gnn_layers = nn.LayerList(
            [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])

		# 分类模块
        self.node_cls = nn.Linear(node_embed, num_classes)
        self.edge_cls = nn.Linear(edge_embed, 2)

    def forward(self, input, targets):
        relations, texts, x = input
        node_nums, char_nums = [], []
        for text in texts:
            node_nums.append(text.shape[0])
            char_nums.append(paddle.sum((text > -1).astype(int), axis=-1))

        max_num = max([char_num.max() for char_num in char_nums])
        all_nodes = paddle.concat([
            paddle.concat(
                [text, paddle.zeros(
                    (text.shape[0], max_num - text.shape[1]))], -1)
            for text in texts
        ])
        temp = paddle.clip(all_nodes, min=0).astype(int)
        embed_nodes = self.node_embed(temp)
        rnn_nodes, _ = self.rnn(embed_nodes)

        b, h, w = rnn_nodes.shape
        nodes = paddle.zeros([b, w])
        all_nums = paddle.concat(char_nums)
        valid = paddle.nonzero((all_nums > 0).astype(int))
        temp_all_nums = (
            paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)
        temp_all_nums = paddle.expand(temp_all_nums, [
            temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]
        ])
        temp_all_nodes = paddle.gather(rnn_nodes, valid)
        N, C, A = temp_all_nodes.shape
        one_hot = F.one_hot(
            temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])

        one_hot = paddle.multiply(
            temp_all_nodes, one_hot.astype("float32")).sum(axis=1,
                                                           keepdim=True)
        t = one_hot.expand([N, 1, A]).squeeze(1)
        nodes = paddle.scatter(nodes, valid.squeeze(1), t)
		
		# 图像特征和文本特征融合
        if x is not None:
            nodes = self.fusion([x, nodes])

        all_edges = paddle.concat(
            [rel.reshape([-1, rel.shape[-1]]) for rel in relations])
        embed_edges = self.edge_embed(all_edges.astype('float32'))
        embed_edges = F.normalize(embed_edges)

		# 将节点特征和边的权重信息整合到一起
		# 图推理模块
        for gnn_layer in self.gnn_layers:
            nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)

 		# 分类模块
        node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
        return node_cls, edge_cls
  • 不过,从MMOCR中实现的代码来看,并没有采用Bi-LSTM,而是采用的是LSTM。这一点从Issue #491中,可以看到开发者对此的说法是:使用Bi-LSTM和单独使用LSTM并没有多大区别。
  • 同时,我又去比对了PaddleOCR中此部分的实现,同样也是采用的LSTM。严重怀疑,PaddleOCR的该算法代码是转写的MMOCR的。😂
融合模块
  • 该部分主要参考论文中将到的方法实现的,通过将高阶矩阵拆分了几个低阶的小矩阵,从而降低系统占用内存和计算时间
  • 源码部分
    class Block(nn.Layer):
    def __init__(self,
                 input_dims,
                 output_dim,
                 mm_dim=1600,
                 chunks=20,
                 rank=15,
                 shared=False,
                 dropout_input=0.,
                 dropout_pre_lin=0.,
                 dropout_output=0.,
                 pos_norm='before_cat'):
        super().__init__()
        self.rank = rank
        self.dropout_input = dropout_input
        self.dropout_pre_lin = dropout_pre_lin
        self.dropout_output = dropout_output
        assert (pos_norm in ['before_cat', 'after_cat'])
        self.pos_norm = pos_norm
    
        # Modules
        self.linear0 = nn.Linear(input_dims[0], mm_dim)
        self.linear1 = (self.linear0
                        if shared else nn.Linear(input_dims[1], mm_dim))
    
        self.merge_linears0 = nn.LayerList()
        self.merge_linears1 = nn.LayerList()
        self.chunks = self.chunk_sizes(mm_dim, chunks)
        for size in self.chunks:
            ml0 = nn.Linear(size, size * rank)
            self.merge_linears0.append(ml0)
    
            ml1 = ml0 if shared else nn.Linear(size, size * rank)
            self.merge_linears1.append(ml1)
        self.linear_out = nn.Linear(mm_dim, output_dim)
    
    def forward(self, x):
        x0 = self.linear0(x[0])
        x1 = self.linear1(x[1])
        bs = x1.shape[0]
        if self.dropout_input > 0:
            x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
            x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
        x0_chunks = paddle.split(x0, self.chunks, -1)
        x1_chunks = paddle.split(x1, self.chunks, -1)
    
    	# 这里是转换为低阶的矩阵乘法的关键
        zs = []
        for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks, self.merge_linears0,
                                      self.merge_linears1):
            m = m0(x0_c) * m1(x1_c)  # bs x split_size*rank
            m = m.reshape([bs, self.rank, -1])
            z = paddle.sum(m, 1)
            if self.pos_norm == 'before_cat':
                z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
                z = F.normalize(z)
            zs.append(z)
        z = paddle.concat(zs, 1)
    
        if self.pos_norm == 'after_cat':
            z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
            z = F.normalize(z)
    
        if self.dropout_pre_lin > 0:
            z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
        z = self.linear_out(z)
        
        if self.dropout_output > 0:
            z = F.dropout(z, p=self.dropout_output, training=self.training)
        return z
    
    def chunk_sizes(self, dim, chunks):
        split_size = (dim + chunks - 1) // chunks
        sizes_list = [split_size] * chunks
        sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
        return sizes_list
    
文本节点与边权重获得部分
  • 该模块主要将文本框信息与文本信息之间构建图的关系,便于基于位置来推理文本之间的关系。
  • 论文中介绍的可是相当清楚,同时结合源码对应来看,自己就会有豁然开朗的感觉。
  • 先说论文中的公式介绍部分,主要集中在论文的公式(5~14)
    • 论文中将文档图像作为一个图来看待, g = ( N , E ) \mathcal{g} = (\mathcal{N}, \mathcal{E}) g=(N,E),其中 N = n i \mathcal{N} = {n_{i}} N=ni, n i n_{i} ni是由text node r i r_{i} ri的特征向量组成, E = e i j \mathcal{E} = e_{ij} E=eij e i j e_{ij} eij r i r_{i} ri r j r_{j} rj之间的连接权重。
    • text node之间的spatial relation的计算主要由论文中公式(5-9)给出:
      spatial_relations计算
    • 该部分对应源码为:
    def compute_relation(self, boxes):
        """Compute relation between every two boxes."""
        # 公式(5)
        x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
    
    	# 公式(6)
        x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
        
        # 公式(7、8、9)
        ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
        dxs = (x1s[:, 0][None] - x1s) / self.norm
        dys = (y1s[:, 0][None] - y1s) / self.norm
    
        xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
        whs = ws / hs + np.zeros_like(xhhs)
        relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
        bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
        return relations, bboxes
    
    • 随后,将文本节点之间的信息嵌入到边的权重之中,具体按照下面公式:-spatial_embedded
    • 该部分对应源码主要位于GNNLayer类中。
图推理模块

论文中提到该部分主要是以迭代的方式来逐步优化节点特征,详细参见论文中公式(13~14):
n i l + 1 = n i l + σ ( W l ( ∑ j ≠ i α i j l e i j l ) ) α i j l = e x p ( e i j ) ∑ k ≠ i e x p ( e i k ) \begin{aligned} n_{i}^{l+1} &= n_{i}^{l} + \sigma(W^{l}(\sum \limits_{ j\neq{i} } \alpha_{ij}^{l} \bm{e}_{ij}^{l})) \\ \alpha_{ij}^{l} &= \frac{exp(e_{ij})}{\sum \limits_{k\neq{i}} exp(e_{ik})} \end{aligned} nil+1αijl=nil+σ(Wl(j=iαijleijl))=k=iexp(eik)exp(eij)

  • 对应源码部分如下:
    class GNNLayer(nn.Layer):
        def __init__(self, node_dim=256, edge_dim=256):
            super().__init__()
            self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
            self.coef_fc = nn.Linear(node_dim, 1)
            self.out_fc = nn.Linear(node_dim, node_dim)
            self.relu = nn.ReLU()
    
        def forward(self, nodes, edges, nums):
            # 合并节点信息和边信息
            start, cat_nodes = 0, []
            for num in nums:
                sample_nodes = nodes[start:start + num]
                cat_nodes.append(
                    paddle.concat([
                        paddle.expand(sample_nodes.unsqueeze(1), [-1, num, -1]),
                        paddle.expand(sample_nodes.unsqueeze(0), [num, -1, -1])
                    ], -1).reshape([num**2, -1]))
                start += num
            cat_nodes = paddle.concat([paddle.concat(cat_nodes), edges], -1)
    
    		# 公式(13)
            cat_nodes = self.relu(self.in_fc(cat_nodes))
            coefs = self.coef_fc(cat_nodes)
    
    		# 公式(14)
            start, residuals = 0, []
            for num in nums:
                residual = F.softmax(
                    -paddle.eye(num).unsqueeze(-1) * 1e9 +
                    coefs[start:start + num**2].reshape([num, num, -1]), 1)
                    
                residuals.append((residual * cat_nodes[start:start + num**2]
                                  .reshape([num, num, -1])).sum(1))
                start += num**2
    
            nodes += self.relu(self.out_fc(paddle.concat(residuals)))
            return [nodes, cat_nodes]
    
分类模块

该部分就是两个FC层即可,一个FC对应节点,一个FC对应边。源码

总结

从落地的角度来看这篇工作,似乎并不能很好的使用。我能想到的原因主要有两个:
1. 论文中用的backbone有些厚重,在速度和精度之间没有做到一个trade-off
2. 训练这样一个模型,所需要的数据集并不容易获得,对数据集的要求较高。同时,考虑到大部分应用场景中,中英文居多。但目前并没有相关的中文数据集

  • 3
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值