Back-to-Basics Series: Core Code of the Table Recognition Algorithm (SLANet)

SLANet is the table structure recognition model in PaddleOCR. It combines a PP-LCNet backbone, a CSP-PAN neck, and an attention-based decoding head. The model predicts both structure information and cell coordinates, trained with cross-entropy and Smooth-L1 losses respectively. The core pieces are the AttentionGRUCell and the structure/location generators used for step-by-step decoding. In non-training mode the structure predictions are passed through a softmax. The SLALoss function combines the structure and location losses.
Introduction
Basic Principles
  • SLANet combines PP-LCNet + CSP-PAN + an attention-based decoder.
  • The model predicts two outputs: structure_probs, the HTML tokens describing the table structure, and loc_preds, the regressed cell-box coordinates.
  • The structure branch is trained with cross-entropy loss and the coordinate branch with Smooth-L1 loss (a hand-rolled Smooth-L1 is sketched right after this list).
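For reference, Smooth-L1 is quadratic near zero and linear for large errors, which keeps the coordinate regression robust to outliers. A minimal hand-rolled version, equivalent to F.smooth_l1_loss with its default beta = 1.0 (the function name here is ours, for illustration only):

import torch

def smooth_l1(pred, target, beta=1.0):
    # 0.5 * d**2 / beta where |d| < beta, otherwise |d| - 0.5 * beta
    diff = (pred - target).abs()
    loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    return loss.mean()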
Overall Framework

(Figure: overall SLANet framework — PP-LCNet backbone, CSP-PAN neck, attention-based SLAHead.)
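With the original figure unavailable, here is a shape-only sketch of the data flow, with the backbone and neck stubbed by random tensors (shapes match the demo at the end of this post):

import torch

img = torch.randn(1, 3, 488, 488)   # network input

# PP-LCNet backbone: four stage outputs (stubbed with randn)
feats = [torch.randn(1, c, s, s)
         for c, s in [(64, 122), (128, 61), (256, 31), (512, 16)]]

# CSP-PAN neck: every level mapped to 96 channels; the attention-based
# head consumes the coarsest 1 x 96 x 16 x 16 map
head_in = torch.randn(1, 96, 16, 16)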

Core Code
  • The code below can be copied and run as-is; it exists to expose the implementation details of the algorithm's core.
  • All data are generated with randn, so focus on the tensor shapes at each stage to follow what is being done.
  • The core of an algorithm is rarely mysterious; most of the time it is simply wrapped far too deeply.
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import torch
from torch import nn
from torch.nn import functional as F


class SLAHead(nn.Module):
    def __init__(self, in_channels=96, is_train=False) -> None:
        super().__init__()
        self.max_text_length = 500  # maximum number of decoding steps
        self.hidden_size = 256
        self.loc_reg_num = 4        # values regressed per cell box
        self.out_channels = 30      # size of the structure-token vocabulary
        self.num_embeddings = self.out_channels
        self.is_train = is_train

        self.structure_attention_cell = AttentionGRUCell(in_channels,
                                                         self.hidden_size,
                                                         self.num_embeddings)

        self.structure_generator = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Linear(self.hidden_size, self.out_channels)
        )

        self.loc_generator = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Linear(self.hidden_size, self.loc_reg_num)
        )
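        # structure_generator classifies the next structure (HTML) token;
        # loc_generator regresses that step's cell box from the same state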

    def forward(self, fea):
        batch_size = fea.shape[0]

        # flatten the spatial dims: 1 x 96 x 16 x 16 → 1 x 96 x 256
        fea = torch.reshape(fea, [fea.shape[0], fea.shape[1], -1])

        # channels last: 1 x 256 x 96 (256 positions, 96 channels each)
        fea = fea.permute(0, 2, 1)

        # structure outputs for every decoding step: 1 x 501 x 30
        structure_preds = torch.zeros(batch_size, self.max_text_length + 1,
                                      self.num_embeddings)
        # cell-box outputs for every decoding step: 1 x 501 x 4
        loc_preds = torch.zeros(batch_size, self.max_text_length + 1,
                                self.loc_reg_num)

        hidden = torch.zeros(batch_size, self.hidden_size)
        pre_chars = torch.zeros(batch_size, dtype=torch.int64)

        loc_step, structure_step = None, None
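        # Autoregressive decoding: the hidden state starts at zero, token
        # id 0 plays the role of the start symbol, and each step's argmax
        # token is fed back (one-hot encoded) as the next step's input.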
        for i in range(self.max_text_length + 1):
            hidden, structure_step, loc_step = self._decode(pre_chars,
                                                            fea, hidden)
            pre_chars = structure_step.argmax(dim=1)
            structure_preds[:, i, :] = structure_step
            loc_preds[:, i, :] = loc_step

        # training returns raw logits (what CrossEntropyLoss expects);
        # inference applies softmax to obtain per-token probabilities
        if not self.is_train:
            structure_preds = F.softmax(structure_preds, dim=-1)
        # structure_preds: 1 x 501 x 30
        # loc_preds: 1 x 501 x 4
        return structure_preds, loc_preds

    def _decode(self, pre_chars, features, hidden):
        # the previous token's one-hot vector doubles as its "embedding";
        # cast to float before concatenating with the float context
        emb_features = F.one_hot(pre_chars,
                                 num_classes=self.num_embeddings).float()
        (output, hidden), alpha = self.structure_attention_cell(hidden,
                                                                features,
                                                                emb_features)
        structure_step = self.structure_generator(output)
        loc_step = self.loc_generator(output)
        return hidden, structure_step, loc_step


class AttentionGRUCell(nn.Module):
    def __init__(self, input_size, hidden_size, num_embedding) -> None:
        super().__init__()

        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)
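        # i2h / h2h / score together form an additive-attention scorer:
        #   e = score(tanh(i2h(features) + h2h(hidden)))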

        # a GRUCell performs one decoding step on [context; one-hot token]
        self.gru = nn.GRUCell(input_size + num_embedding, hidden_size)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        # additive attention, following https://arxiv.org/pdf/1704.03549.pdf
        # batch_H: B x 256 x 96 (positions x channels), prev_hidden: B x 256
        batch_H_proj = self.i2h(batch_H)
        prev_hidden_proj = torch.unsqueeze(self.h2h(prev_hidden), dim=1)

        res = torch.tanh(batch_H_proj + prev_hidden_proj)
        e = self.score(res)  # B x 256 x 1

        alpha = F.softmax(e, dim=1)  # attention weights over the 256 positions
        alpha = alpha.permute(0, 2, 1)
        # attention-weighted sum of the visual features: B x 96
        context = torch.squeeze(torch.matmul(alpha, batch_H), dim=1)
        concat_context = torch.cat([context, char_onehots], dim=1)

        # one GRU step; for a GRU the step output equals the new hidden state
        cur_hidden = self.gru(concat_context, prev_hidden)
        return (cur_hidden, cur_hidden), alpha


class SLALoss(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.loss_func = nn.CrossEntropyLoss()
        self.structure_weight = 1.0
        self.loc_weight = 2.0
        self.eps = 1e-12

    def forward(self, pred):
        structure_probs = pred[0]
        # CrossEntropyLoss wants the class dim second: 1 x 30 x 501
        structure_probs = structure_probs.permute(0, 2, 1)

        # random structure targets, demo only: 1 x 501
        structure_target = torch.empty(1, 501, dtype=torch.long).random_(30)
        structure_loss = self.loss_func(structure_probs, structure_target)
        structure_loss = structure_loss * self.structure_weight

        loc_preds = pred[1]  # 1 x 501 x 4
        # random coordinate targets and a random 0/1 validity mask, demo only
        loc_targets = torch.randn(1, 501, 4)
        loc_target_mask = torch.randint(0, 2, (1, 501, 1)).float()

        # sum the masked per-element Smooth-L1 terms, then normalize by the
        # number of valid entries so padded steps do not dilute the loss
        loc_loss = F.smooth_l1_loss(loc_preds * loc_target_mask,
                                    loc_targets * loc_target_mask,
                                    reduction='sum')
        loc_loss = loc_loss * self.loc_weight / (loc_target_mask.sum() + self.eps)

        total_loss = structure_loss + loc_loss
        return total_loss


if __name__ == '__main__':
    # The network input x is 1 x 3 x 488 x 488.
    # The PP-LCNet backbone keeps four stage outputs, with shapes:
    # 1 x 64 x 122 x 122
    # 1 x 128 x 61 x 61
    # 1 x 256 x 31 x 31
    # 1 x 512 x 16 x 16

    # The CSPLayer + PAN neck (top-down and bottom-up paths) maps them to:
    # 1 x 96 x 122 x 122
    # 1 x 96 x 61 x 61
    # 1 x 96 x 31 x 31
    # 1 x 96 x 16 x 16

    # The 1 x 96 x 16 x 16 feature map feeds the attention-based head.
    x = torch.randn(1, 96, 16, 16)

    model = SLAHead(is_train=True)  # raw logits, as CrossEntropyLoss expects
    loss_func = SLALoss()

    y = model(x)

    loss = loss_func(y)

    print(y[0].shape)
    print(y[1].shape)
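Running the script prints torch.Size([1, 501, 30]) and torch.Size([1, 501, 4]): for each of the 501 decoding steps (max_text_length + 1), one score vector over the 30 structure tokens and one 4-value cell box.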
