Donut模型-图像文本阅读以及下游任务的多模态大模型

目录

一、简单介绍

二、Donut的结构

三、预训练的任务

3.1 任务描述

3.2 预训练任务数据集

四、下游任务

4.1 分类任务

4.1.1 任务描述

4.1.2 任务数据集

4.1.2.1 RVL-CDIP数据集

4.2 文档信息抽取任务

4.2.1 任务描述

4.2.2任务数据集

4.2.2.1 CORD数据集

4.2.2.2 Train Ticket数据集

4.2.2.3 名片数据集

4.2.2.4 票据数据

4.2.3 评估指标

4.3  文档问答

五、代码相关

五、实际应用的效果展示

预训练任务效果

火车票提取任务


一、简单介绍

OCR-free Document Understanding Transformer

github地址 GitHub - clovaai/donut: Official Implementation of OCR-free Document Understanding Transformer (Donut) and Synthetic Document Generator (SynthDoG), ECCV 2022

hugging face上的基础模型地址 naver-clova-ix/donut-base · Hugging Face

参考的知乎专栏

Donut:不用OCR中间过程也能理解图片文档 - 知乎

传统的文档理解Visual Document Understanding (VDU) 的步骤如下

简单概括就是:

1. 一个OCR文本提取器,提取到文本信息,文本框坐标信息;

2. 将原始图片编码得到编码后的图片embedding,将文本序列编码得到token sequence embedding,将文本框的位置信息归一化之后进行位置编码得到position embedding;

3. 通过transformer的架构将这些embedding结合起来进行下游任务的训练。

donut相比于传统的方式,直接用给一个E2E的模型取代解耦的OCR+下游语言模型。时间上更快,准确率也更高。

二、Donut的结构

结构很简单,一个视觉编码器和一个NLP的语言解码器构成。编码器和解码器都是基于Transformer的。在本文中作者采用swin-Transformer作为视觉的编码器,(swin_transformer代码地址https://github.com/huggingface/pytorch-image-models/blob/v0.6.13/timm/models/swin_transformer.py)因为效果最好。

编码器和解码器:

编码器。视觉编码器将输入文档图像x∈RH×W×C转换为一组嵌入向量{zi|zi∈Rd,1≤i≤n},其中n是特征图大小或图像块数量,d是编码器的潜在向量的维度。请注意,CNN基础模型[17]或Transformer基础模型[9,40]可以作为编码器网络使用。在本研究中,我们使用Swin Transformer [40],因为在我们的文档解析初步研究中,它表现出最佳的性能。Swin Transformer首先将输入图像x切分成非重叠的块。Swin Transformer块,由一个移位窗口为基础的多头自注意力模块和一个两层的MLP组成,应用于这些块。然后,在每个阶段都应用了块合并层到块令牌上。最后的Swin Transformer块的输出{z}被输入到下一个文本解码器中。

解码器。给定{z},文本解码器生成一个令牌序列(yi)mi=1,其中yi∈Rv是第i个令牌的独热向量,v是令牌词汇表的大小,m是超参数。我们使用BART [33]作为解码器架构。具体来说,我们使用公开可用的预训练的多语言BART模型[38]的权重来初始化解码器模型权重。

(编码器和解码器的介绍完全来自知乎专栏——Donut:不用OCR中间过程也能理解图片文档 - 知乎  )

三、预训练的任务

3.1 任务描述

模型需要学习识别图片上的文字内容,并且是按照人类的阅读习惯(从左到右,从上到下),训练的对象是每个输出token,训练的条件是参考图片以及前面时刻的输出内容,训练目标是最小化交叉熵损失函数。这个任务可以看成给一个伪OCR任务。

当编码器学习完成之后,可以得到图片的视觉编码信息为sequence len为H*W,dim=C的embedding,其中的HW为进经过SWIN Transformer经过16倍下采样之后的长宽(一般是4层下采样,但是不排除其他配置的可能)。将这个视觉编码信息作为解码器的attention层,通过解码器微调学习文字的理解。在本文中,将所有的预训练任务和下游任务(分类,问答,结构化)都转换为一个json序列的预测问题。比如对于分类问题,模型被训练输出[START class][memo][END class],这个可以在后面被转换为 JSON {“class”:“memo”}.

预训练采用teacher-forcing scheme策略。这种策略简单的说,就是采用GT作为前面时刻的输出,也就是当前时刻的输入。而不是采用模型预测的前面时刻的输出。在测试阶段,模型通过一个给定的prompt输出一个token序列。在实验中,对于每个特定的下游任务,采用特殊的新的token作为prompt。具体的token可以参考上面的图3.

图E的左上角介绍了teaching-force的训练策略,右上角是推理阶段。可以看到推理阶段,t时刻的输入是t-1时刻的输出。

右下角介绍了输出格式,然后通过某种特定的转换,变为json的格式(左下角)。具体是怎么转换的,可以用正则表达式来编写(源代码中的方式)。

3.2 预训练任务数据集

IIT-CDIP数据集(据说很大),包含11million英语的文档图片。一个商业应用l CLOVA OCR API可以用于生成虚假的文本标签。作者还提供了一个自制的开源应用SynthDoG,用于生成中文、日文、韩文以及英文,作者分别为这四种语言各自生成0.5Million的数据。

四、下游任务

4.1 分类任务

4.1.1 任务描述

判断文档图片属于 哪个类别的能力。

4.1.2 任务数据集

4.1.2.1 RVL-CDIP数据集

这是英文文档的分类任务,图片为二值化图片,非彩色图片;类别有16类,分别如下所示

数据集介绍地址 https://adamharley.com/rvl-cdip/​​​​​

 数据集信息:40万张图片,每个类别2.5万张,32万用于训练,4万用于验证,4万有用测试;每张图片的最大尺寸(H或W)不超过1000

4.2 文档信息抽取任务

4.2.1 任务描述

document information extraction (IE)

采用结构化的方式构建文档信息;需要模型可以同时阅读文本内容、理解文本、文本的layout结构以及语义信息来总结归纳整个文档信息。

之前有的部分工作只需要抽取一些预先定义的关键字信息即可,那些任务不需要理解文本的layout结构,所以会比较简单一点,列举在这里。

4.2.2任务数据集

4.2.2.1 CORD数据集

数据集地址:

naver-clova-ix/cord-v1 · Datasets at Hugging Face

菜单或者账单的抽取,主要是我们去餐馆的小票。800张训练图片,100张验证集图片,100张测试集图片。在官方给出的代码中,task_name=cord-v2。

图片的GT(拷贝了一部分,一张图片上的GT太长了)

"{"gt_parse": {"menu": [{"nm": "HAKAU UDANG", "cnt": "4", "price": "92,000"}, {"nm": "SIAO MAI BABI", "cnt": "4", "price": "80,000"}, {"nm": "CEKER AYAM", "cnt": "3", "price": "60,000"}, {"nm": "BAKPAO BKR C CRISPY", "cnt": "2", "price": "42,000"}, {"nm": "TAHU GORENG CRISPY", "cnt": "3", "price": "60,000"}], "sub_total": {"subtotal_price": "334,000"}, "total": {"total_price": "334,000", "cashprice": "350,000", "changeprice": "-16,000", "menutype_cnt": "5", "menuqty_cnt": "16"}}, "meta": {"version": "1.0.0", "split": "train", "image_id": 2, "image_size": {"width": 720, "height": 1280}}, "valid_line": [{"words": [{"quad": {"x2": 272, "y3": 489, "x3": 270, "y4": 489, "x1": 174, "y1": 461, "x4": 174, "y2": 459}, "is_key": 0, "row_id": 539268, "text": "HAKAU"}, {"quad": {"x2": 379, "y3": 488, "x3": 380, "y4": 488, "x1": 280, "y1": 460, "x4": 278, "y2": 457}, "is_key": 0, "row_id": 539268, "text": "UDANG"}], "category": "menu.nm", "group_id": 3}, {"words": [{"quad": {"x2": 166, "y3": 488, "x3": 166, "y4": 488, "x1": 149, "y1": 463, "x4": 149, "y2": 463}, "is_key": 0, "row_id": 539268, "text": "4"}], "category": "menu.cnt", "group_id": 3}, {"words": [{"quad": {"x2": 627, "y3": 483, "x3": 629, "y4": 485, "x1": 531, "y1": 453, "x4": 533, "y2": 449}, "is_key": 0, "row_id": 539268, "text": "92,000"}], "category": "menu.price", "group_id": 3}, {"words": [{"quad": {"x2": 244, "y3": 521, "x3": 244, "y4": 521, "x1": 174, "y1": 496, "x4": 170, "y2": 494}, "is_key": 0, "row_id": 539269, "text": "SIAO"}, {"quad": {"x2": 303, "y3": 521, "x3": 305, "y4": 522, "x1": 252, "y1": 495, "x4": 250, "y2": 495}, "is_key": 0, "row_id": 539269, "text": "MAI"}, {"quad": {"x2": 379, "y3": 521, "x3": 377, "y4": 520, "x1": 313, "y1": 494, "x4": 311, "y2": 493}, "is_key": 0, "row_id": 539269, "text": "BABI"}], "category": "menu.nm", "group_id": 4}, {"words": [{"quad": {"x2": 166, "y3": 525, "x3": 166, "y4": 525, "x1": 147, "y1": 498, "x4": 147, "y2": 498}, "is_key": 0, "row_id": 539269, "text": "4"}], "category": "menu.cnt", "group_id": 4}, {"words": [{"quad": {"x2": 630, "y3": 518, "x3": 627, "y4": 518, "x1": 532, "y1": 487, "x4": 532, "y2": 484}, "is_key": 0, "row_id": 539269, "text": "80,000"}], 

4.2.2.2 Train Ticket数据集

中国蓝色的火车票信息抽取任务,数据源未开源。数据集规模为1500训练图片和400张测试图片。抽取其中的票据号码、起始站、列车编号、到达站、用户名等等8个信息。

task_name=zhtrainticket

4.2.2.3 名片数据集

同样未开源,数据集规模为20000张训练数据,300张测试数据,300张验证数据。任务和火车票相似。

4.2.2.4 票据数据

同样未开源,包括4万训练集,1K测试集,1K验证集。任务和火车票相似,但是每张图片中可以抽取的信息entity数量高达81条。

4.2.3 评估指标

1. F1指标

2. TED指标

4.3  文档问答

用于探索图像文档理解的潜力而设置的任务(一般这么说的话,任务的表现都不会很好......).

编码层的输入为图片,解码层的输入为Question的序列化表示;输出回答的序列表示。

数据集为DocVQA;图片文档数量为12000,并且包含5万个问答;

数据集地址 Overview - Document Visual Question Answering - Robust Reading Competition (Document Visual Question Answering),这个网站是一个包含多个任务的挑战赛页面。作者这里参与的是task1。

 评估指标采用ANLS (Average Normalized Levenshtein Similarity) 这是一个基于编辑距离的指标。实际的score是通过将结果上传到任务网站,由任务网站实现的评测。

五、代码相关

github代码位置 GitHub - clovaai/donut: Official Implementation of OCR-free Document Understanding Transformer (Donut) and Synthetic Document Generator (SynthDoG), ECCV 2022

采用hugging face的代码封装 https://github.com/huggingface/transformers

https://github.com/rwightman/pytorch-image-models

训练阶段采用半精度训练,优化器为Adam,预训练的初始学习率为1e-4,微调阶段的学习率为1e-5到1e-4之间。预训练了200Ksteps,batchsize=196;在64张A100上进行。采用了梯度截断策略,输入图片的尺寸为2560*1920.

在微调阶段,输入图片的尺寸为1280*960;并且训练时间也减少了。在Cord和ticket数据上微调花费0.5hours,单卡A100

在CDIP数据上或者在问答数据集上,依旧采用2560*1920的分辨输入;在64张A100上DocVQA数据集训练了1天,CDIP数据集开销了2天。

实际使用的代码——火车票检测下游任务

!pip install transformers==4.25.1
!pip install pytorch-lightning==1.6.4
!pip install timm==0.5.4
!pip install gradio
!pip install donut-python

import argparse
import gradio as gr
import torch
from PIL import Image

from donut import DonutModel
def demo_process_vqa(input_img, question):
    global pretrained_model, task_prompt, task_name
    input_img = Image.fromarray(input_img)
    user_prompt = task_prompt.replace("{user_input}", question)
    output = pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]
    return output


def demo_process(input_img):
    global pretrained_model, task_prompt, task_name
    input_img = Image.fromarray(input_img)
    output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
    return output
    
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="zhtrainticket")
parser.add_argument("--pretrained_path", type=str, default="naver-clova-ix/donut-base-finetuned-zhtrainticket")
args, left_argv = parser.parse_known_args()

task_name = args.task
if "docvqa" == task_name:
    task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
else:  # rvlcdip, cord, ...
    task_prompt = f"<s_{task_name}>"

pretrained_model = DonutModel.from_pretrained(args.pretrained_path)

if torch.cuda.is_available():
    pretrained_model.half()
    device = torch.device("cuda")
    pretrained_model.to(device)
else:
    pretrained_model.encoder.to(torch.bfloat16)

pretrained_model.eval()

demo = gr.Interface(
    fn=demo_process_vqa if task_name == "docvqa" else demo_process,
    inputs=["image", "text"] if task_name == "docvqa" else "image",
    outputs="json",
    title=f"Donut 🍩 demonstration for `{task_name}` task",
)
demo.launch()

五、实际应用的效果展示

预训练任务效果

taskname=synthdog(如果是中文、日文韩文等非英语)

taskname=iitcdip (如果是英文)

不知道是不是放出来的模型不是最好的模型,测试了营业执照、表格、身份证、普通图片、营业执照以及普通的纯文本。看起来只有在最后一个图片上效果比较好。

 

 

 

火车票提取任务

看看在训练好的特征抽取上的效果(不是很好,没有一张是正确的)

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值