Table Detection and Recognition in Images

1. Project Introduction

This article uses Microsoft's open-source table detection model, table-transformer-detection, to implement table detection and recognition.
The walkthrough is split into three parts:

  • Table detection: locate the regions of an image or PDF file that contain tables
  • Table structure recognition: within each detected table region, identify the table's rows, columns, and header positions, and from those derive the positions of individual cells
  • Table data extraction: on top of the recognized structure, run OCR on each cell to obtain its text and thereby recover the full table's data

2. Environment Setup

2.1 Server Configuration

(Screenshot of the server configuration omitted.)

2.2 Conda Environment

conda create -n Microsoft python=3.8 pip
conda activate Microsoft 
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
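
(The contents of requirements.txt are not shown in the post; judging from the code in later sections, it needs at least transformers, torch, pillow, paddleocr plus paddlepaddle, and gradio.)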

Model download:

https://huggingface.co/microsoft/table-transformer-detection/tree/main
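
Besides downloading in the browser, the weights can also be fetched programmatically. This is a minimal sketch using huggingface_hub (an assumption on my part; the post does not show how the model was downloaded), with an illustrative local_dir:

from huggingface_hub import snapshot_download

# Fetch the detection checkpoint into a local directory (path is illustrative)
snapshot_download(repo_id="microsoft/table-transformer-detection",
                  local_dir="./table-transformer-detection")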

3. Table Detection

Detect the regions of an image or PDF file that contain tables.
Partial code:

from transformers import AutoImageProcessor, TableTransformerForObjectDetection
import torch
from PIL import Image

file_path = "./images/demo.jpg"
image = Image.open(file_path).convert("RGB")
file_name = file_path.split('/')[-1].split('.')[0]
# Load the image processor and detection model (this step was elided in the
# original post; the checkpoint name matches the download link above)
image_processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-detection")
model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")
inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)

# convert outputs (bounding boxes and class logits) to COCO API
target_sizes = torch.tensor([image.size[::-1]])
results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[0]

i = 0
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    box = [round(i, 2) for i in box.tolist()]
    print(
        f"Detected {model.config.id2label[label.item()]} with confidence "
        f"{round(score.item(), 3)} at location {box}"
    )

    region = image.crop(box)  # crop the detected table region
    region.save(f'./images/{file_name}_{i}.png')
    i += 1

The output is as follows:

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge']` instead.
Detected table with confidence 0.998 at location [429.76, 217.22, 730.43, 441.51]
Detected table with confidence 0.997 at location [89.02, 215.52, 372.13, 440.98]
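
The code above takes an image file directly. For PDF input (mentioned at the start of this section), one common approach is to rasterize each page first; this is a sketch assuming pdf2image (and its poppler dependency) is installed, with the PDF path as an illustrative example:

from pdf2image import convert_from_path

# Render each PDF page as a PIL image and run the same detection on it
pages = convert_from_path("./images/demo.pdf", dpi=200)
for page_idx, page in enumerate(pages):
    inputs = image_processor(images=page, return_tensors="pt")
    outputs = model(**inputs)
    target_sizes = torch.tensor([page.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs, threshold=0.9, target_sizes=target_sizes)[0]
    print(f"page {page_idx}: {len(results['boxes'])} table(s) detected")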

4. Table Structure Recognition

Partial code:

from transformers import DetrFeatureExtractor, TableTransformerForObjectDetection
import torch
from PIL import Image



# Use DetrFeatureExtractor for preprocessing
feature_extractor = DetrFeatureExtractor()

file_path = "./images/demo_1.png"
image = Image.open(file_path).convert("RGB")

# Preprocess the image into model inputs
encoding = feature_extractor(images=image, return_tensors="pt")
# Load the structure recognition model (elided in the original post; the
# companion checkpoint microsoft/table-transformer-structure-recognition is
# assumed here, matching the label map printed in the results below)
model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")

# Forward pass
with torch.no_grad():
    outputs = model(**encoding)

target_sizes = [image.size[::-1]]
results = feature_extractor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
print(results)

# Label 3 is "table column header" (see the id2label map printed below)
header_box_list = [results['boxes'][i].tolist() for i in range(len(results['boxes'])) if results['labels'][i].item() == 3]
for idx, box in enumerate(header_box_list):
    print(idx)
    crop_image = image.crop(box)
    crop_image.save(f'header_{idx}.png')

The output is as follows:

{0: 'table', 1: 'table column', 2: 'table row', 3: 'table column header', 4: 'table projected row header', 5: 'table spanning cell'}
{'scores': tensor([0.9938, 0.9916, 0.9990, 0.9951, 0.9988, 0.9951, 0.9914, 0.9138, 0.9976,
        0.9996]), 'labels': tensor([2, 2, 1, 2, 1, 1, 2, 3, 1, 0]), 'boxes': tensor([[ 17.4817, 111.3828, 246.3970, 153.2920],
        [ 17.5167,  66.2340, 246.3717, 105.3657],
        [ 17.5555,  34.3023,  66.3797, 183.1057],
        [ 17.4001,  33.6988, 246.5164,  65.5520],
        [146.3985,  34.1393, 225.7526, 182.9028],
        [226.7389,  34.1763, 247.0905, 183.1046],
        [ 17.4087, 151.3451, 246.1476, 183.2057],
        [ 17.1707,  33.7522, 246.4425,  64.2807],
        [ 67.4367,  33.9697, 146.2301, 183.5033],
        [ 17.5343,  34.2323, 246.6799, 183.0271]])}
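
To sanity-check these predictions, it helps to draw them on the table image. This is a minimal sketch (not from the original post) using PIL's ImageDraw, run right after the code above; the color mapping is an arbitrary choice:

from PIL import ImageDraw

# Draw predicted boxes: columns in red, rows in blue, column headers in green
debug_image = image.copy()
draw = ImageDraw.Draw(debug_image)
colors = {1: "red", 2: "blue", 3: "green"}
for label, box in zip(results["labels"], results["boxes"]):
    draw.rectangle(box.tolist(), outline=colors.get(label.item(), "gray"), width=2)
debug_image.save("structure_debug.png")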

5. Table Data Extraction

Partial code:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Imports and model setup (elided in the original post; reconstructed from the
# names used below, so treat the checkpoints and OCR settings as assumptions)
from uuid import uuid4
from urllib.request import urlretrieve

import gradio as gr
import torch
from PIL import Image
from paddleocr import PaddleOCR
from transformers import (AutoImageProcessor, DetrFeatureExtractor,
                          TableTransformerForObjectDetection)

ocr = PaddleOCR(use_angle_cls=True, lang="ch")
image_processor = AutoImageProcessor.from_pretrained("microsoft/table-transformer-detection")
detect_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")
feature_extractor = DetrFeatureExtractor()
structure_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")
def paddle_ocr(image_path):
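    # Run PaddleOCR on one cell image and concatenate the recognized text lines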
    result = ocr.ocr(image_path, cls=True)
    ocr_result = []
    for idx in range(len(result)):
        res = result[idx]
        if res:
            for line in res:
                print(line)
                ocr_result.append(line[1][0])
    return "".join(ocr_result)


def table_detect(image_box, image_url):
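    # Detect tables in either the uploaded image (image_box) or an image
    # downloaded from image_url, and save each detected table as a crop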
    if not image_url:
        file_name = str(uuid4())
        image = Image.fromarray(image_box).convert('RGB')
    else:
        image_path = f"./images/{uuid4()}.png"
        file_name = image_path.split('/')[-1].split('.')[0]
        urlretrieve(image_url, image_path)
        image = Image.open(image_path).convert('RGB')
    inputs = image_processor(images=image, return_tensors="pt")
    outputs = detect_model(**inputs)
    # convert outputs (bounding boxes and class logits) to COCO API
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[0]

    i = 0
    output_images = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i, 2) for i in box.tolist()]
        print(
            f"Detected {detect_model.config.id2label[label.item()]} with confidence "
            f"{round(score.item(), 3)} at location {box}"
        )

        region = image.crop(box)  # crop the detected table region
        output_image_path = f'./images/{file_name}_{i}.jpg'
        region.save(output_image_path)
        output_images.append(output_image_path)
        i += 1
    print(f"output_images:{output_images}")
    return output_images


def table_ocr(output_images):
    # The gallery value arrives here as a list of (image_path, caption) pairs;
    # run structure recognition and OCR on the first detected table
    if not output_images:
        return []
    first_image = output_images[0][0]
    image = Image.open(first_image).convert("RGB")

    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = structure_model(**encoding)

    target_sizes = [image.size[::-1]]
    results = feature_extractor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]
    print(f"results: {results}\n")
    # get column and row
    columns = []
    rows = []
    for i in range(len(results['boxes'])):
        _id = results['labels'][i].item()
        if _id == 1:  # label 1: table column
            columns.append(results['boxes'][i].tolist())
        elif _id == 2:  # label 2: table row
            rows.append(results['boxes'][i].tolist())

    sorted_columns = sorted(columns, key=lambda x: x[0])
    sorted_rows = sorted(rows, key=lambda x: x[1])
    # ocr by cell
    ocr_results = []
    for row in sorted_rows:
        row_result = []
        for col in sorted_columns:
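            # A cell is the intersection of a column's x-range and a row's y-range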
            rect = [col[0], row[1], col[2], row[3]]
            crop_image = image.crop(rect)
            image_path = 'cell.png'
            crop_image.save(image_path)
            row_result.append(paddle_ocr(image_path=image_path))
        print(f"row_result: {row_result}\n")
        ocr_results.append(row_result)

    print(f"ocr_results: {ocr_results}\n")
    return ocr_results


if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image_box = gr.Image()
                image_urls = gr.TextArea(lines=1, placeholder="Enter image url", label="Images")
                # image_index = gr.TextArea(lines=1, placeholder="Image Number", label="No")
            with gr.Column():
                gallery = gr.Gallery(label="Tables", show_label=False, elem_id="gallery", columns=[3], rows=[1],
                                     object_fit="contain", height="auto")
                detect = gr.Button("Table Detection")
                submit = gr.Button("Table OCR")
                ocr_outputs = gr.DataFrame(label='Table',
                                           interactive=True,
                                           wrap=True)
        detect.click(fn=table_detect,
                     inputs=[image_box, image_urls],
                     outputs=gallery)
        submit.click(fn=table_ocr,
                     inputs=[gallery],
                     outputs=ocr_outputs)
    demo.launch(server_name="0.0.0.0", server_port=7676, share=True)
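
Once launched, the Gradio UI is served on port 7676: upload an image or paste an image URL, click Table Detection to populate the gallery with cropped tables, then click Table OCR to extract the first table into the DataFrame.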

The result is shown below:
(Screenshot of the Gradio demo omitted.)

6. Summary

In the recognition results, the first and last columns of a table are sometimes missed. The likely cause is an issue in the table structure recognition step; follow-up work will continue to optimize it.
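
One mitigation worth trying (an assumption on my part, not verified in the post) is to expand each cell rectangle by a small margin before cropping, so characters touching the table border are not clipped:

def expand_box(box, margin, width, height):
    # Grow an [x0, y0, x1, y1] rectangle by `margin` pixels on every side,
    # clamped to the image bounds, before cropping the cell for OCR
    x0, y0, x1, y1 = box
    return [max(0, x0 - margin), max(0, y0 - margin),
            min(width, x1 + margin), min(height, y1 + margin)]

# Hypothetical usage inside table_ocr:
#   rect = expand_box([col[0], row[1], col[2], row[3]], 3, *image.size)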
