引言
- 文档图像中的关键信息提取任务(Key Information Extraction, KIE)是实现办公场景自动化的一项重要任务。
- 如果从OCR场景来看,KIE任务可以作为对OCR提取结果内容的结构化抽取来使用。
- 这次介绍的论文是Spatial Dual-Modality Graph Reasoning for Key Information Extraction,该论文是在去年3月份挂在arxiv上的,但是具体作者机构暂时没有给出。从论文arxiv主页上来看,其官方Code的链接是MMOCR。大胆猜测,难道作者是商汤的?这就不得而知了。
- 值得一提的是,PaddleOCR最近也集成了该算法,看来关键信息抽取这一研究方向得到了工业界的一些重视,距离真正落地使用不会太远了。
- 由于该算法在MMOCR和PaddleOCR中均有实现,考虑到跑通示例程序的便利性,这里以PaddleOCR中实现为例,指出相应的实现源码,用作学习之用。
SDMG-R整体结构
- 从结构图来看,论文的思路比较清晰。整体结构可分为三个模块:双模态融合模块、图推理模块和分类模块三个。
双模态融合模块
该模块结合视觉特征和文本特征。其中视觉特征 v i {v_{i}} vi来自U-Net和ROI-Pooling提取所得,文本特征 t i {t_{i}} ti则是通过Bi-LSTM提取得到的。两个不同模态的特征通过Kronecker乘积操作得以融合。这样得以充分利用图像的一维和二维信息,这也是Dual Modality名称的由来。
Backbone部分
在PaddleOCR中backbone:U-Net,位于ppocr/modeling/backbones/kie_unet_sdmgr.py,主要包括U-Net和ROI Align两部分。对应的源码如下(省略部分代码),重点在于forward
部分:
class Kie_backbone(nn.Layer):
def __init__(self, in_channels, **kwargs):
super(Kie_backbone, self).__init__()
self.out_channels = 16
self.img_feat = UNet()
self.maxpool = nn.MaxPool2D(kernel_size=7)
def bbox2roi(self, bbox_list):
pass
def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):
pass
def forward(self, inputs):
# img shape: [1, 3, 1024, 1024]
img = inputs[0]
relations, texts = inputs[1], inputs[2]
gt_bboxes, tag, img_size = inputs[3], inputs[5], inputs[-1]
# 前处理
img, relations, texts, gt_bboxes = self.pre_process(
img, relations, texts, gt_bboxes, tag, img_size)
# 此时img shape: [1, 3, 512, 512]
x = self.img_feat(img)
# output x shape: [1, 16, 512, 512]
boxes, rois_num = self.bbox2roi(gt_bboxes)
feats = paddle.fluid.layers.roi_align(
x,
boxes,
spatial_scale=1.0,
pooled_height=7,
pooled_width=7,
rois_num=rois_num)
# feats shape: [26, 16, 7, 7]
feats = self.maxpool(feats).squeeze(-1).squeeze(-1)
# output feats shape: [26, 16]
return [relations, texts, feats]
Head部分
该部分是文本特征的提取,以及backbone部分提取所得图像特征和文本特征的合并部分。这部分主要集中在PaddleOCR代码中head部分下。
class SDMGRHead(nn.Layer):
def __init__(self,
in_channels,
num_chars=92,
visual_dim=16,
fusion_dim=1024,
node_input=32,
node_embed=256,
edge_input=5,
edge_embed=256,
num_gnn=2,
num_classes=26,
bidirectional=False):
super().__init__()
# 融合模块
self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
self.node_embed = nn.Embedding(num_chars, node_input, 0)
hidden = node_embed // 2 if bidirectional else node_embed
# 单层LSTM模块
self.rnn = nn.LSTM(
input_size=node_input, hidden_size=hidden, num_layers=1)
# 图推理模块
self.edge_embed = nn.Linear(edge_input, edge_embed)
self.gnn_layers = nn.LayerList(
[GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
# 分类模块
self.node_cls = nn.Linear(node_embed, num_classes)
self.edge_cls = nn.Linear(edge_embed, 2)
def forward(self, input, targets):
relations, texts, x = input
node_nums, char_nums = [], []
for text in texts:
node_nums.append(text.shape[0])
char_nums.append(paddle.sum((text > -1).astype(int), axis=-1))
max_num = max([char_num.max() for char_num in char_nums])
all_nodes = paddle.concat([
paddle.concat(
[text, paddle.zeros(
(text.shape[0], max_num - text.shape[1]))], -1)
for text in texts
])
temp = paddle.clip(all_nodes, min=0).astype(int)
embed_nodes = self.node_embed(temp)
rnn_nodes, _ = self.rnn(embed_nodes)
b, h, w = rnn_nodes.shape
nodes = paddle.zeros([b, w])
all_nums = paddle.concat(char_nums)
valid = paddle.nonzero((all_nums > 0).astype(int))
temp_all_nums = (
paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)
temp_all_nums = paddle.expand(temp_all_nums, [
temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]
])
temp_all_nodes = paddle.gather(rnn_nodes, valid)
N, C, A = temp_all_nodes.shape
one_hot = F.one_hot(
temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])
one_hot = paddle.multiply(
temp_all_nodes, one_hot.astype("float32")).sum(axis=1,
keepdim=True)
t = one_hot.expand([N, 1, A]).squeeze(1)
nodes = paddle.scatter(nodes, valid.squeeze(1), t)
# 图像特征和文本特征融合
if x is not None:
nodes = self.fusion([x, nodes])
all_edges = paddle.concat(
[rel.reshape([-1, rel.shape[-1]]) for rel in relations])
embed_edges = self.edge_embed(all_edges.astype('float32'))
embed_edges = F.normalize(embed_edges)
# 将节点特征和边的权重信息整合到一起
# 图推理模块
for gnn_layer in self.gnn_layers:
nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)
# 分类模块
node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
return node_cls, edge_cls
- 不过,从MMOCR中实现的代码来看,并没有采用Bi-LSTM,而是采用的是LSTM。这一点从Issue #491中,可以看到开发者对此的说法是:使用Bi-LSTM和单独使用LSTM并没有多大区别。
- 同时,我又去比对了PaddleOCR中此部分的实现,同样也是采用的LSTM。严重怀疑,PaddleOCR的该算法代码是转写的MMOCR的。😂
融合模块
- 该部分主要参考论文中将到的方法实现的,通过将高阶矩阵拆分了几个低阶的小矩阵,从而降低系统占用内存和计算时间
- 源码部分:
class Block(nn.Layer): def __init__(self, input_dims, output_dim, mm_dim=1600, chunks=20, rank=15, shared=False, dropout_input=0., dropout_pre_lin=0., dropout_output=0., pos_norm='before_cat'): super().__init__() self.rank = rank self.dropout_input = dropout_input self.dropout_pre_lin = dropout_pre_lin self.dropout_output = dropout_output assert (pos_norm in ['before_cat', 'after_cat']) self.pos_norm = pos_norm # Modules self.linear0 = nn.Linear(input_dims[0], mm_dim) self.linear1 = (self.linear0 if shared else nn.Linear(input_dims[1], mm_dim)) self.merge_linears0 = nn.LayerList() self.merge_linears1 = nn.LayerList() self.chunks = self.chunk_sizes(mm_dim, chunks) for size in self.chunks: ml0 = nn.Linear(size, size * rank) self.merge_linears0.append(ml0) ml1 = ml0 if shared else nn.Linear(size, size * rank) self.merge_linears1.append(ml1) self.linear_out = nn.Linear(mm_dim, output_dim) def forward(self, x): x0 = self.linear0(x[0]) x1 = self.linear1(x[1]) bs = x1.shape[0] if self.dropout_input > 0: x0 = F.dropout(x0, p=self.dropout_input, training=self.training) x1 = F.dropout(x1, p=self.dropout_input, training=self.training) x0_chunks = paddle.split(x0, self.chunks, -1) x1_chunks = paddle.split(x1, self.chunks, -1) # 这里是转换为低阶的矩阵乘法的关键 zs = [] for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks, self.merge_linears0, self.merge_linears1): m = m0(x0_c) * m1(x1_c) # bs x split_size*rank m = m.reshape([bs, self.rank, -1]) z = paddle.sum(m, 1) if self.pos_norm == 'before_cat': z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z)) z = F.normalize(z) zs.append(z) z = paddle.concat(zs, 1) if self.pos_norm == 'after_cat': z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z)) z = F.normalize(z) if self.dropout_pre_lin > 0: z = F.dropout(z, p=self.dropout_pre_lin, training=self.training) z = self.linear_out(z) if self.dropout_output > 0: z = F.dropout(z, p=self.dropout_output, training=self.training) return z def chunk_sizes(self, dim, chunks): split_size = (dim + chunks - 1) // chunks sizes_list = [split_size] * chunks sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim) return sizes_list
文本节点与边权重获得部分
- 该模块主要将文本框信息与文本信息之间构建图的关系,便于基于位置来推理文本之间的关系。
- 论文中介绍的可是相当清楚,同时结合源码对应来看,自己就会有豁然开朗的感觉。
- 先说论文中的公式介绍部分,主要集中在论文的公式(5~14)
- 论文中将文档图像作为一个图来看待, g = ( N , E ) \mathcal{g} = (\mathcal{N}, \mathcal{E}) g=(N,E),其中 N = n i \mathcal{N} = {n_{i}} N=ni, n i n_{i} ni是由text node r i r_{i} ri的特征向量组成, E = e i j \mathcal{E} = e_{ij} E=eij, e i j e_{ij} eij是 r i r_{i} ri和 r j r_{j} rj之间的连接权重。
- text node之间的spatial relation的计算主要由论文中公式(5-9)给出:
- 该部分对应源码为:
def compute_relation(self, boxes): """Compute relation between every two boxes.""" # 公式(5) x1s, y1s = boxes[:, 0:1], boxes[:, 1:2] # 公式(6) x2s, y2s = boxes[:, 4:5], boxes[:, 5:6] # 公式(7、8、9) ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1) dxs = (x1s[:, 0][None] - x1s) / self.norm dys = (y1s[:, 0][None] - y1s) / self.norm xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs whs = ws / hs + np.zeros_like(xhhs) relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1) bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32) return relations, bboxes
- 随后,将文本节点之间的信息嵌入到边的权重之中,具体按照下面公式:-
- 该部分对应源码主要位于GNNLayer类中。
图推理模块
论文中提到该部分主要是以迭代的方式来逐步优化节点特征,详细参见论文中公式(13~14):
n
i
l
+
1
=
n
i
l
+
σ
(
W
l
(
∑
j
≠
i
α
i
j
l
e
i
j
l
)
)
α
i
j
l
=
e
x
p
(
e
i
j
)
∑
k
≠
i
e
x
p
(
e
i
k
)
\begin{aligned} n_{i}^{l+1} &= n_{i}^{l} + \sigma(W^{l}(\sum \limits_{ j\neq{i} } \alpha_{ij}^{l} \bm{e}_{ij}^{l})) \\ \alpha_{ij}^{l} &= \frac{exp(e_{ij})}{\sum \limits_{k\neq{i}} exp(e_{ik})} \end{aligned}
nil+1αijl=nil+σ(Wl(j=i∑αijleijl))=k=i∑exp(eik)exp(eij)
- 对应源码部分如下:
class GNNLayer(nn.Layer): def __init__(self, node_dim=256, edge_dim=256): super().__init__() self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim) self.coef_fc = nn.Linear(node_dim, 1) self.out_fc = nn.Linear(node_dim, node_dim) self.relu = nn.ReLU() def forward(self, nodes, edges, nums): # 合并节点信息和边信息 start, cat_nodes = 0, [] for num in nums: sample_nodes = nodes[start:start + num] cat_nodes.append( paddle.concat([ paddle.expand(sample_nodes.unsqueeze(1), [-1, num, -1]), paddle.expand(sample_nodes.unsqueeze(0), [num, -1, -1]) ], -1).reshape([num**2, -1])) start += num cat_nodes = paddle.concat([paddle.concat(cat_nodes), edges], -1) # 公式(13) cat_nodes = self.relu(self.in_fc(cat_nodes)) coefs = self.coef_fc(cat_nodes) # 公式(14) start, residuals = 0, [] for num in nums: residual = F.softmax( -paddle.eye(num).unsqueeze(-1) * 1e9 + coefs[start:start + num**2].reshape([num, num, -1]), 1) residuals.append((residual * cat_nodes[start:start + num**2] .reshape([num, num, -1])).sum(1)) start += num**2 nodes += self.relu(self.out_fc(paddle.concat(residuals))) return [nodes, cat_nodes]
分类模块
该部分就是两个FC层即可,一个FC对应节点,一个FC对应边。源码
总结
从落地的角度来看这篇工作,似乎并不能很好的使用。我能想到的原因主要有两个:
1. 论文中用的backbone有些厚重,在速度和精度之间没有做到一个trade-off
2. 训练这样一个模型,所需要的数据集并不容易获得,对数据集的要求较高。同时,考虑到大部分应用场景中,中英文居多。但目前并没有相关的中文数据集