1. Background
Recently I have been working on multimodal OCR + DOC-VQA tasks and have relied on the XFUND dataset throughout. In a real project, however, better results require training on real data. So how do you build a dataset in the DOC-VQA format?
Let's first go over a few basic concepts:
- Text detection: locates the text regions in an input image.
- Text recognition: recognizes the text content in an image; the input usually comes from the image region cropped out of a text box produced by text detection.
- Key Information Extraction (KIE): an important task in Document VQA that extracts the required key information from an image, e.g. the name and citizen ID number from an ID card. The set of fields is usually fixed within a given task but differs between tasks.
KIE is usually studied as two subtasks:
- SER: Semantic Entity Recognition, which recognizes and classifies the text in an image.
- RE: Relation Extraction, which classifies each detected text, e.g. as a question or an answer, and then finds the matching answer for each question. Based on the RE task, relations between the text items in an image can be extracted, e.g. identifying question-answer pairs.

(Figure: red boxes mark questions, blue boxes mark answers, and a green line connects each question to its answer. The category and OCR recognition result are also shown at the top left of each OCR detection box.)
2. The XFUND dataset
Let's see what the XFUND dataset looks like. To use XFUND for training or validation, it first has to be converted into the form "image path<TAB>JSON string", where the JSON string looks like:
{
"height": 3508, # 图像高度
"width": 2480, # 图像宽度
"ocr_info": [
{
"text": "邮政地址:", # 单个文本内容
"label": "question", # 文本所属类别
"bbox": [261, 802, 483, 859], # 单个文本框
"id": 54, # 文本索引
"linking": [[54, 60]], # 当前文本和其他文本的关系 [question, answer]
"words": []
},
{
"text": "湖南省怀化市市辖区",
"label": "answer",
"bbox": [487, 810, 862, 859],
"id": 60,
"linking": [[54, 60]],
"words": []
}
]
}
The label field is generally one of header, question, answer, or other. A question and an answer that belong together share the same linking entry, and one question can have multiple answers.
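To make the converted format concrete, here is a minimal parsing sketch; the file name label_converted.txt is only a placeholder for a file in this "image path<TAB>JSON" form:
import json
# each line is: image path, a tab, then the JSON string shown above
with open('label_converted.txt', encoding='utf-8') as f:
    for line in f:
        image_path, info_str = line.rstrip('\n').split('\t')
        info = json.loads(info_str)
        for item in info['ocr_info']:
            # linking pairs take the form [question_id, answer_id]
            print(item['id'], item['label'], item['text'], item['linking'])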
2.1 Download the dataset
%cd /home/aistudio/
! mkdir /home/aistudio/XFUND && mkdir /home/aistudio/XFUND/zh_train && mkdir /home/aistudio/XFUND/zh_val
! unzip -q -o /home/aistudio/data/data140302/XFUND_ori.zip -d /home/aistudio/data/data140302/
! mv /home/aistudio/data/data140302/XFUND_ori/zh.train /home/aistudio/XFUND/zh_train/image
! mv /home/aistudio/data/data140302/XFUND_ori/zh.val /home/aistudio/XFUND/zh_val/image
! rm -rf /home/aistudio/data/data140302/XFUND_ori
/home/aistudio
mkdir: cannot create directory '/home/aistudio/XFUND': File exists
2.2 Convert to a trainable format
! unzip -q -o PaddleOCR.zip
# If PaddleOCR still needs to be installed or updated, run the following steps
#! git clone https://gitee.com/PaddlePaddle/PaddleOCR
# Install dependencies
! pip install -r /home/aistudio/PaddleOCR/requirements.txt > install.log
#! pip install paddleocr >> install.log
# Install NLP and other packages
# ! pip install yacs gnureadline paddlenlp==2.2.1 >> install.log
# ! pip install xlsxwriter >> install.log
! pip install regex
WARNING: You are using pip version 22.0.4; however, version 22.1.2 is available.
You should consider upgrading via the '/opt/conda/envs/python35-paddle120-env/bin/python -m pip install --upgrade pip' command.
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: regex in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (2022.6.2)
WARNING: You are using pip version 22.0.4; however, version 22.1.2 is available.
You should consider upgrading via the '/opt/conda/envs/python35-paddle120-env/bin/python -m pip install --upgrade pip' command.
# The following command can also be used to generate training and validation sets for text detection and text recognition
%cd /home/aistudio/
! python trans_xfund_data.py
/home/aistudio
Corrupt JPEG data: 18 extraneous bytes before marker 0xc4
Corrupt JPEG data: bad Huffman code
Corrupt JPEG data: premature end of data segment
3. Creating a custom dataset
3.1 Unzip the image data
Only 33 images from the dataset below were annotated for this article.
%cd /home/aistudio/XTOWER
# Unzip the images; only 33 are used in this article
! unzip -q -o image.zip
# Unzip the cropped images of the "text annotation results"
! unzip -q -o crop_img.zip
# More images can be obtained with the following commands
# ! unzip -q -o /home/aistudio/data/data142101/Scan_0012_0004.zip -d /home/aistudio/XTOWER/image
# ! unzip -q -o /home/aistudio/data/data142101/1234.zip -d /home/aistudio/XTOWER/image
/home/aistudio/XTOWER
3.2 Using PPOCRLabel to assist annotation
Reasons for choosing PPOCRLabel:
- It supports automatic annotation, which saves a lot of time.
- It supports a KIE mode, so a category can be assigned to each text box.
Start PPOCRLabel (KIE mode):
PPOCRLabel --lang ch --kie True
PaddleOCR needs to be installed locally. Package up the images in the image directory under XTOWER and download them to your machine, auto-annotate all of them with PPOCRLabel, then review the annotations and make any adjustments. In PPOCRLabel, run "Export Label Result" and "Export Recognition Result": the former generates a Label.txt file (for text detection); the latter creates a crop_img folder holding the cropped images, plus a rec_gt.txt file (for text recognition).
Upload the following files and folder to the /home/aistudio/XTOWER directory:
- crop_img
- rec_gt.txt
- Label.txt
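Before continuing, it is worth sanity-checking the uploaded files. A small sketch (paths as above; the tab-separated line formats are those produced by PPOCRLabel's two export functions):
import json
base = '/home/aistudio/XTOWER/'
# Label.txt: one "image path<TAB>JSON annotation list" line per image
with open(base + 'Label.txt', encoding='utf-8') as f:
    image_file, ocr_info = f.readline().split('\t')
    print(image_file, '->', len(json.loads(ocr_info)), 'boxes')
# rec_gt.txt: one "cropped image path<TAB>text" line per text box
with open(base + 'rec_gt.txt', encoding='utf-8') as f:
    print(sum(1 for _ in f), 'recognition samples')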
3.3 Automatically linking QA pairs and splitting text
After running "auto annotation" and checking and fine-tuning the detected boxes, the saved Label.txt contains the position of every piece of text in the image, but no QA relations yet. The RE relations can be built in PPOCRLabel by editing "change box key class", e.g. question_1, question_2, answer_1, answer_2, ... The header must also be labeled.
Can answers be matched to questions automatically, so that not even the questions need manual labeling? Yes, they can.
We can first define the structure of the form to be annotated: row by row, which headers there are, which questions, and which texts are other. We can even declare that certain questions span a multiple of the single line height.
Note: questions on the same row go into the same list, and texts explicitly declared as other will never be matched to a question as an answer. The structure definition follows below.
# Define a simple document structure; other document types can be defined here in the same way
documents=[
{
"headers":["《塔类业务交付验收单》"],
"questions":[
["客 户:"],
["需求名称","铁塔名称"],
["运营商区县","铁塔区县"],
["需求单号","站址编码"],
["所属批次","产品单元"],
["站点经度","站点纬度"],
["验收日期:"],
["塔型"],
["机房类型"],
["挂高"],
["场景"],
["运营商共享"],
["建设内容"],
["市电引入费用原值(元)"],
["存在问题及解决办法"],
["铁塔公司验收负责人:","运营商验收负责人:"],
["其他参加验收人员签字:"]
],
"others":[],
"style":{
"铁塔名称":{"max_height":2},
"建设内容":{"max_height":3}
}
},
{
"headers":["泰安电信“一站一案”需求线下确认单"],
"questions":[
["需求名称","铁塔站点名称"],
["建设单位","设计单位"],
["站址编码","所属批次"],
["站点位置情况"],
["产品类型"],
["塔形、天线挂高"],
["供电类型"],
["建设类型"],
["场租类型"],
["电力引入费用预估"],#如果无answer,标记为other
["交付时间要求:","起租计费时间"],
["起租方式说明:"],
["新建/改造立项编码"],
["塔桅"],
["配套"],
["市电"]
],
"others":["精确到小数点后6位","参考标准工期"],
}
]
import json
from PIL import Image
import regex
import math
import copy
def get_word_list(s1):
# Split a sentence into units: Chinese by character, English by word
regEx = regex.compile(r'[\W]+') # split on runs of non-word characters (anything other than letters and digits)
res = regex.compile(r"([\u4e00-\u9fa5\pZ\(\)\:\。\,\?\《\》\“\”])") # [\u4e00-\u9fa5] is the Chinese character range; the capture group keeps each match
p1 = regEx.split(s1)
str1_list = []
for s in p1:
# res.split keeps each matched Chinese character/symbol as its own list item
ret = res.split(s)
for ch in ret:
str1_list.append(ch)
# list_word1 = [w for w in str1_list if w in [" "," "] or len(w.strip()) > 0] # drop empty strings
list_word1 = [w for w in str1_list if len(w) > 0]
return list_word1
def isChinese(word):
if '\u4e00' <= word <= '\u9fff':
return True
if word in "()《》“”‘’ 。,:【】「」?": # Chinese punctuation
return True
elif len(word) ==1 and 32 <= ord(word) <= 255:
return False
elif len(regex.findall(r'\p{Z}', word)) != 0: # Unicode separator characters, e.g. the full-width space
return True
return False
"""
分割文字标注框,横向水平均分
box = [[x1,y2],[x2,y2]]
"""
def splitBox(box,words):
left_padding=2
top_padding=0
# The four points of a word box are ordered top-left, top-right, bottom-right, bottom-left; only the top and bottom edges need to be divided evenly.
# To handle skewed text, the vertical direction is divided proportionally as well.
split_width=[]
word_count=0
for word in words:
if isChinese(word):
split_width.append(1)
word_count = word_count +1
else:
# how wide upper-case, lower-case letters and digits are
split_width.append(0.5*len(word)) # a non-Chinese character takes 0.5 of a Chinese character's width
word_count = word_count+ (0.5*len(word))
dx = (int(box[1][0]) - int(box[0][0]) - left_padding) / word_count # width of one character unit
dy = (int(box[2][1]) - int(box[3][1])) / word_count # vertical step per character unit (for skewed boxes)
wordboxes = []
i=0
bx=left_padding
by=0
for word in words:
x = dx * split_width[i]
if isChinese(word):
px= x * 0.8 # shrink the box slightly so Chinese character features are captured more reliably
else:
px= x
y = dy * split_width[i]
make_ocrinfo = {}
make_ocrinfo['transcription']=word
make_ocrinfo['key_cls']='word'
make_ocrinfo['difficult']=False
make_ocrinfo['points']=[
[int(box[0][0]+bx),int(box[0][1]+by+top_padding)],
[int(box[0][0]+bx+px),int(box[0][1]+by+top_padding)],
[int(box[0][0]+bx+px),int(box[3][1]+by+y-top_padding)],
[int(box[0][0]+bx),int(box[3][1]+by+y-top_padding)]
]
# print(box)
wordboxes.append(make_ocrinfo)
# print(make_ocrinfo)
bx=bx+x
by=by+y
i = i+1
return wordboxes
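As a quick illustration of the two helpers above (the expected values follow from the logic, assuming an axis-aligned box):
print(get_word_list('塔型A1'))  # ['塔', '型', 'A1']: Chinese split per character, Latin/digit runs kept together
box = [[100, 50], [200, 50], [200, 80], [100, 80]]  # top-left, top-right, bottom-right, bottom-left
for wb in splitBox(box, get_word_list('塔型A1')):
    print(wb['transcription'], wb['points'])  # one key_cls='word' box per unit; 'A1' is weighted 0.5 per character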
# Convert PPOCRLabel auto-annotations into the new format; no manual label-type changes needed
def make_qa_linking(labelfile,newlabelfile,documents=None,split_text=False,ignore_word=False):
file = open(labelfile)
newinfo = {}
i = 0
lines = ""
json_lines = ""
while True:
line = file.readline()
if not line:
break
image_path,ocr_info = line.split("\t");
ocr_infos = json.loads(ocr_info)
# 1. Find the form header to determine which form this image is
curr_doument = None
for ocrinfo in ocr_infos:
for document in documents:
if(ocrinfo['transcription'] in document['headers']):
curr_doument = document
if curr_doument is None:
print(image_path + ' does not match any defined form')
continue
# Separate out the word-type boxes
all_word_info=[]
new_ocr_infos=[]
for ocr_info in ocr_infos:
if ocr_info['key_cls'] == "word" :
if ignore_word == False:
all_word_info.append(copy.deepcopy(ocr_info))
else:
new_ocr_infos.append(copy.deepcopy(ocr_info))
for i,ocrinfo in enumerate(new_ocr_infos):
# if words are ignored, the word info is dropped
# if ignore_word == True:
# continue
words=[]
# Determine which annotation each word belongs to
for word_info in all_word_info:
word_center_x = word_info['points'][0][0] + int((word_info['points'][1][0] - word_info['points'][0][0]) / 2)
word_center_y = word_info['points'][0][1] + int((word_info['points'][3][1] - word_info['points'][0][1]) / 2)
# if the word's center point falls inside a question or answer box
if ocrinfo['points'][0][0] < word_center_x < ocrinfo['points'][1][0] and ocrinfo['points'][0][1] < word_center_y < ocrinfo['points'][3][1] :
words.append(word_info)
ocrinfo['words'] = words
# 2. Locate the questions, and widen boxes whose first/last character is a full-width symbol, because PPOCRLabel measures leading/trailing Chinese symbols such as (, ), 《, 》, : at half width
questions= {}
q=1
ocr_infos=copy.deepcopy(new_ocr_infos)
for ocrinfo in ocr_infos:
# **********
# This block compensates for PPOCRLabel's automatic detection: when special Chinese symbols appear before or after the text, the annotation box is widened to make character splitting easier.
# Once the QA relations have been established via key_cls, the boxes of these special symbols are no longer adjusted.
#
if 'key_cls' not in ocrinfo or ocrinfo['key_cls'] == 'None':
if ocrinfo['transcription'] == "客户:":
ocrinfo['transcription'] = "客 户:"
if ocrinfo['transcription'][0] in ["《","("]:
ocrinfo['points'][0][0] = ocrinfo['points'][0][0] -20
ocrinfo['points'][3][0] = ocrinfo['points'][3][0] -20
if ocrinfo['transcription'][-1] in ['》',')',':','。',',']:
ocrinfo['points'][1][0] = ocrinfo['points'][1][0] +20
ocrinfo['points'][2][0] = ocrinfo['points'][2][0] +20
# ********* end
row=1
for row_questions in curr_doument['questions']:
# print(row_questions)
col=1
col_count = len(row_questions)
for one_quesion in row_questions:
if ocrinfo['transcription'] == one_quesion:
ocrinfo['key_cls'] = 'question_'+str(q)
ocrinfo['construct'] = [row,col,col_count]
questions['q_'+str(row)+'_'+str(col)]=ocrinfo
q = q+1
col = col +1
row = row+1
# 3. For each question, find the matching answers (a multi-line answer is temporarily kept as several answers in the form and can be merged after word splitting); other and header are kept as-is
for ocrinfo in ocr_infos:
for bianhao in questions:
question = questions[bianhao]
question_points = question['points']
[q_row,q_col,q_col_count] = question['construct']
one_question_answers=[]
label_points = ocrinfo['points']
# Exclude other labels. An answer must satisfy: its left edge is right of the question's right edge, its bottom edge is below the question's top edge, and its top edge is above the question's bottom edge
if ocrinfo['transcription'] in curr_doument['headers'] :
ocrinfo['key_cls'] = "header"
elif ocrinfo['key_cls'] == "word":
continue
elif "words" in ocrinfo and len(ocrinfo['words']) > 0:
continue
else:
max_height = 1
style = curr_doument.get('style', {}) # per-question style overrides; use .get since not every document defines 'style'
if question['transcription'] in style:
if "max_height" in style[question['transcription']]:
max_height = style[question['transcription']]["max_height"]
# extend the vertical range in which answers are accepted, for multi-line questions
q_lineheight = abs(question_points[3][1]-question_points[0][1]) * (max_height-1)/2
if label_points[0][0] > question_points[1][0] and ( label_points[3][1] > question_points[1][1]-q_lineheight and label_points[0][1] < question_points[2][1] + q_lineheight ):
if q_col < q_col_count:
next_col=q_col+1
bh='q_'+str(q_row)+'_'+str(next_col)
if bh in questions:
next_question= questions[bh]
if label_points[1][0] < next_question['points'][0][0]:
answer_name= copy.deepcopy(question['key_cls']).replace('question_','answer_')
ocrinfo['key_cls'] = answer_name # auxiliary answer label; the number matches the corresponding question
one_question_answers.append(copy.deepcopy(ocrinfo))
elif q_col == q_col_count:
answer_name= copy.deepcopy(question['key_cls']).replace('question_','answer_')
ocrinfo['key_cls'] = answer_name # auxiliary answer label; the number matches the corresponding question
one_question_answers.append(copy.deepcopy(ocrinfo))
question['answers'] = one_question_answers
# 4. Split the text and generate new word-level annotation boxes
if split_text == True:
for ocrinfo in ocr_infos:
# boxes that are already split are not split again, so manual fine-tuning is preserved
if ocrinfo['key_cls'] == "word":
continue
if "words" in ocrinfo and len(ocrinfo['words']) > 0:
continue
qwords = get_word_list(ocrinfo['transcription'])
qwordboxes = splitBox(ocrinfo['points'],qwords)
ocrinfo['words']=qwordboxes
if "answers" in ocrinfo:
for answer in ocrinfo['answers']:
if "words" in answer or answer['key_cls'][0:8] == "question" or answer['key_cls'][0:6] == "answer":
continue
words = get_word_list(answer['transcription'])
wordboxes = splitBox(answer['points'],words)
answer['words']=wordboxes
# Assemble back into the PPOCRLabel format
txt_document = []
for ocrinfo in ocr_infos:
if 'words' in ocrinfo:
for word in ocrinfo['words']:
txt_document.append(word)
del ocrinfo['words']
if ocrinfo['key_cls'] in ["other","None"] :
# map None to other
ocrinfo["key_cls"] = "other"
if "construct" in ocrinfo:
del ocrinfo["construct"]
if "answers" in ocrinfo:
del ocrinfo['answers']
if "words" in ocrinfo:
del ocrinfo['words']
txt_document.append(ocrinfo)
lines +=image_path+"\t"+json.dumps(txt_document,ensure_ascii=False)+"\n"
with open(newlabelfile,'w+',encoding='utf-8') as f2:
f2.writelines(lines)
# Split the text and build the KIE labels
make_qa_linking('/home/aistudio/XTOWER/Label.txt','/home/aistudio/XTOWER/Label_kie_word.txt',documents,split_text=True)
# Skip the split words; this variant can be used to export text recognition results.
make_qa_linking('/home/aistudio/XTOWER/Label.txt','/home/aistudio/XTOWER/Label_kie_no_word.txt',documents,split_text=False,ignore_word=True)
# Both generated files can be opened directly in PPOCRLabel
Label_kie_word.txt is also in the text detection format and can be opened directly in PPOCRLabel; simply overwrite Label.txt with it.
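For example, to inspect the result locally you could overwrite Label.txt like this (the backup file name is just a suggestion):
import shutil
base = '/home/aistudio/XTOWER/'
shutil.copyfile(base + 'Label.txt', base + 'Label.txt.bak')       # keep a backup of the original annotations
shutil.copyfile(base + 'Label_kie_word.txt', base + 'Label.txt')  # PPOCRLabel will now show the split word boxes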

(Figure: the annotations shown in PPOCRLabel after text splitting.)
Annotation tips:
- Long texts mixing Chinese, English, and special symbols can be annotated in segments, which keeps the automatic splitting error small.
- The file produced after splitting can be imported back into PPOCRLabel for fine-tuning (it can directly overwrite Label.txt).
- If a header was recognized incorrectly, fix it first; otherwise the image cannot be matched to any defined document structure.
Label_kie_no_word.txt does not contain the split key_cls="word" boxes; it can be used with PPOCRLabel's "Export Recognition Result" feature to build a text recognition dataset.
3.4 Split the dataset
Now split Label_kie_word.txt (with the renamed QA labels and the split text) into training and validation sets; a simple 7:3 split is used here.
# Split the text detection data into training and validation sets
import random
import os
import shutil
image_path = "/home/aistudio/XTOWER/"
det_gt_kie_file = "/home/aistudio/XTOWER/Label_kie_word.txt"
det_gt_train = "/home/aistudio/XTOWER/train_data/det_gt_train.txt" # not yet the final text detection format
det_gt_val = "/home/aistudio/XTOWER/val_data/det_gt_val.txt"
train_dir = os.path.dirname(det_gt_train)
val_dir = os.path.dirname(det_gt_val)
if os.path.isdir(train_dir):
shutil.rmtree(train_dir)
os.mkdir(train_dir)
os.mkdir(train_dir+'/image')
if os.path.isdir(val_dir):
shutil.rmtree(val_dir)
os.mkdir(val_dir)
os.mkdir(val_dir+'/image')
newinfo = {}
i = 0
json_lines = ""
with open(det_gt_kie_file) as f:
lines = f.readlines();
random.shuffle (lines)
list_len = len(lines)
train_len= int(0.7 * list_len)
train_data = lines[:train_len]
val_data = lines[train_len:]
# print(list_len)
# print(len(train_data))
# print(len(val_data))
with open(det_gt_train,'w+',encoding='utf-8') as f1:
f1.writelines(train_data)
for line in train_data:
image_file,_ = line.split("\t");
shutil.move(image_path+image_file,train_dir+'/image/')
with open(det_gt_val,'w+',encoding='utf-8') as f2:
f2.writelines(val_data)
for line in val_data:
image_file,_ = line.split("\t");
shutil.move(image_path+image_file,val_dir+'/image/')
# Split the text recognition data into training and validation sets
import random
import os
import shutil
image_path = "/home/aistudio/XTOWER/"
rec_file = "/home/aistudio/XTOWER/rec_gt.txt"
rec_train = "/home/aistudio/XTOWER/rec_train.txt"
rec_val = "/home/aistudio/XTOWER/rec_val.txt"
train_dir = os.path.dirname(rec_train)
val_dir = os.path.dirname(rec_val)
newinfo = {}
i = 0
json_lines = ""
with open(rec_file) as f:
lines = f.readlines();
random.shuffle (lines)
list_len = len(lines)
train_len= int(0.8 * list_len)
train_data = lines[:train_len]
val_data = lines[train_len:]
# print(list_len)
# print(len(train_data))
# print(len(val_data))
with open(rec_train,'w+',encoding='utf-8') as f1:
f1.writelines(train_data)
with open(rec_val,'w+',encoding='utf-8') as f2:
f2.writelines(val_data)
3.5 Convert the text detection format to the DOC-VQA format
import json
from PIL import Image
import re
import copy
# Convert the text detection format to the DOC-VQA format
def det_gt_kie_vqa(det_gt_file = 'Label.txt',normalize_file='normalize.json',dataset_path='../train_data'):
file = open(det_gt_file)
newinfo = {}
i = 0
lines = ""
while True:
line = file.readline()
if not line:
break
image_path,ocr_info = line.split("\t");
# image_id = image_path[-8:][:4] # take the last 4 characters of the file name (without the extension)
img = Image.open(dataset_path+"/"+image_path)
newinfo['width'] = img.width
newinfo['height'] = img.height
ocr_infos = json.loads(ocr_info)
# Separate out the word-type boxes
all_word_info=[]
qa=[]
for ocr_info in ocr_infos:
if ocr_info['key_cls'] == "word":
words_dict={
'box':[ocr_info['points'][0][0],
ocr_info['points'][0][1],
ocr_info['points'][2][0],
ocr_info['points'][2][1]],
'text':ocr_info['transcription']
}
all_word_info.append(copy.deepcopy(words_dict))
else:
qa.append(copy.deepcopy(ocr_info))
links={}
ocr_id=1
# Extract the questions and answers separately
q=[]
a=[]
for ocr_info in qa:
ocr_info['id'] = ocr_id
if ocr_info['key_cls'][0:8] == "question":
question_id = ocr_info['key_cls'].replace("question_","")
q.append([question_id,ocr_id])
elif ocr_info['key_cls'][0:6] == "answer":
question_answer_id = ocr_info['key_cls'].replace("answer_","")
a.append([question_answer_id,ocr_id])
ocr_id=ocr_id+1
# QA relations
for cls_id,ocrid in q:
link=[]
for cls_id2,ocrid2 in a:
if cls_id == cls_id2:
link.append([ocrid,ocrid2])
links[cls_id]=link
newocrinfos = []
ocr_id = 1
for ocr_info in qa:
question_id = None # set only for question/answer boxes
newocrinfo={}
newocrinfo['text'] = ocr_info['transcription']
newocrinfo['bbox'] = [
ocr_info['points'][0][0],
ocr_info['points'][0][1],
ocr_info['points'][2][0],
ocr_info['points'][2][1],
]
newocrinfo['id'] = ocr_id
if ocr_info['key_cls'][0:8] == "question":
question_id = ocr_info['key_cls'].replace("question_","")
newocrinfo['label'] = 'question'
elif ocr_info['key_cls'][0:6] == "answer":
question_id = ocr_info['key_cls'].replace("answer_","")
newocrinfo['label'] = "answer"
elif ocr_info['key_cls'] == 'header':
newocrinfo['label'] ="header"
else:
# newocrinfo['label'] = ocr_info['key_cls']
newocrinfo['label'] = "other"
if question_id is not None and question_id in links:
newocrinfo['linking'] = links[question_id]
else:
newocrinfo['linking'] = []
# Determine which annotation each word belongs to
words=[]
for word_info in all_word_info:
word_center_x = word_info['box'][0] + int((word_info['box'][2] - word_info['box'][0]) / 2)
word_center_y = word_info['box'][1] + int((word_info['box'][3] - word_info['box'][1]) / 2)
# if the word's center point falls inside a question or answer box
if newocrinfo['bbox'][0] < word_center_x < newocrinfo['bbox'][2] and newocrinfo['bbox'][1] < word_center_y < newocrinfo['bbox'][3] :
words.append(word_info)
newocrinfo['words'] = words
newocrinfos.append(newocrinfo)
# if newocrinfo['label'] =='answer':
# print('ocr_id=',ocr_id)
# print(newinfo)
ocr_id = ocr_id+1
newinfo['ocr_info'] = newocrinfos
# print(newocrinfos)
# break
# print(all_word_info)
# break
lines += image_path +"\t" + json.dumps(newinfo,ensure_ascii=False) + "\n"
i=i+1
with open(normalize_file,'w+',encoding='utf-8') as f2:
# file2.seek(0) # move the pointer to the start
f2.writelines(lines)
# Generate the trainable format for SER and RE
det_gt_kie_vqa(det_gt_file='/home/aistudio/XTOWER/train_data/det_gt_train.txt',normalize_file='/home/aistudio/XTOWER/train_data/normalize_train.json',dataset_path='/home/aistudio/XTOWER/train_data')
det_gt_kie_vqa(det_gt_file='/home/aistudio/XTOWER/val_data/det_gt_val.txt',normalize_file='/home/aistudio/XTOWER/val_data/normalize_val.json',dataset_path='/home/aistudio/XTOWER/val_data')
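A small sanity check on the generated files could look like this, counting labels per image and verifying that every linking pair references existing box ids (assuming the calls above succeeded):
import json
from collections import Counter
with open('/home/aistudio/XTOWER/train_data/normalize_train.json', encoding='utf-8') as f:
    for line in f:
        image_path, info = line.rstrip('\n').split('\t')
        ocr_info = json.loads(info)['ocr_info']
        ids = {o['id'] for o in ocr_info}
        # linking pairs are [question_id, answer_id] and must point at real boxes
        assert all(q in ids and a in ids for o in ocr_info for q, a in o['linking'])
        print(image_path, Counter(o['label'] for o in ocr_info))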
import shutil
di_xtower = set()
def to_det_gt(filename,di=set()):
"""
将kie标注格式转为的文本检测格式,并返回全部的文本
"""
new_docs = ""
with open(filename, "r", encoding='utf-8') as f:
docs = f.readlines()
for doc in docs:
image_file,ocr_info = doc.split("\t");
ocr_infos = json.loads(ocr_info)
txt_document=[]
for ocr_info in ocr_infos:
if "key_cls" not in ocr_info or ocr_info['key_cls'] != 'word':
txt_document.append({'transcription':ocr_info['transcription'],'points':ocr_info['points']})
# collect every character for the dictionary
di = di | set(ocr_info["transcription"])
new_docs += image_file+"\t"+json.dumps(txt_document,ensure_ascii=False)+"\n"
with open(filename, "w", encoding='utf-8') as f:
f.writelines(new_docs)
return di
di_xtower=to_det_gt("/home/aistudio/XTOWER/train_data/det_gt_train.txt",di_xtower)
di_xtower=to_det_gt("/home/aistudio/XTOWER/val_data/det_gt_val.txt",di_xtower)
# Build the character dictionary
baseline_label = '/home/aistudio/PaddleOCR/ppocr/utils/ppocr_keys_v1.txt'
shutil.copyfile(baseline_label, '/home/aistudio/XTOWER/word_dict.txt')
with open(baseline_label, 'r', encoding='utf-8') as f:
all_chars = f.read()
with open('/home/aistudio/XTOWER/xtower_dict.txt', 'w', encoding='utf-8') as f:
for char in di_xtower:
f.write(char+'\n')
with open('/home/aistudio/XTOWER/word_dict.txt', 'a', encoding='utf-8') as f:
f.write('\n')
for char in di_xtower:
if char not in all_chars:
f.write(char+'\n')
print(char)
️
☑
4. Results
Only the prediction results are shown here. For how to train the SER and RE models, see my project 《多模态技术在工业场景中的应用实践:表单识别重命名》.
5. Summary
5.1 Strengths
- Annotation time is greatly reduced; for pure-Chinese text, almost no adjustment is needed.
- Annotation can continue in the already-split file without text being split twice, because split boxes carry a specific key_cls and are skipped.
- Split characters are automatically assigned to the text box they belong to.
- Multi-line answers are automatically linked to their question.
5.2 Weaknesses
- Sentences mixing a lot of Chinese, English, and symbols split poorly; the word boxes still need manual adjustment.
- Automatic splitting works poorly on strongly skewed text.
- The automatic question-answer linking mechanism is fairly simple, so its generality is limited.
5.3 Possible improvements
- Train a model to recognize individual units (Chinese characters, Chinese symbols, English words) and locate each unit's coordinates and angle, enabling fully automatic annotation.
- Hopefully PPOCRLabel will support automatic text splitting and RE linking natively.
Author: tianxingxia
Original project: https://aistudio.baidu.com/aistudio/projectdetail/4197468