Paddle的OCR标签转化为TXT格式

之前一直用别人造好的,对与xml的转化不甚了解,经过几小时奋战,总算写出了能用点的代码,本着能用就行,代码简单不美观0.0

'''
2022.5.25 OCR xml转txt
'''
import os
import glob
import shutil
import re
import os
import json
import xml.etree.ElementTree as ET
def horizon_rect2box(xmin, ymin, xmax, ymax):
    p0x, p0y = xmin, ymin
    p1x, p1y = xmax, ymin
    p2x, p2y = xmin, ymax
    p3x, p3y = xmax, ymax
    points = [(p0x, p0y), (p1x, p1y), (p2x, p2y), (p3x, p3y)]
    return points

tpath= r'D:\yanyi\project_process\ocr\data/'
labeltxt = r'D:\yanyi\project_process\ocr/label.txt'
filelist = glob.glob(os.path.join(tpath, "*.xml"))
save_file = open(labeltxt, "w")
for fp in filelist:
    root = ET.parse(os.path.join(tpath, fp)).getroot()

    fi=root.find('filename')
    size = root.find('size')
    data = []
    fileName = root.find('filename').text
    for child in root.findall('object'):  # 找到图片中的所有框
        label = child.find('name').text
        ob_name = child.find('name')
        if (child.find('bndbox')):  # 找到框的标注值并进行读取
            ob_box = child.find('bndbox')
            xmin = ob_box.find('xmin').text
            ymin = ob_box.find('ymin').text
            xmax = ob_box.find('xmax').text
            ymax = ob_box.find('ymax').text
            bndbox = horizon_rect2box(xmin, ymin, xmax, ymax)
            dic = {"transcription": label, "points": bndbox}
            data.append(dic)
        if (child.find('polygon')):
            ob_gon = child.find('polygon')
            if(len(ob_gon)==8):
                x1 ,y1= ob_gon.find('x1').text,ob_gon.find('y1').text
                x2 ,y2= ob_gon.find('x2').text,ob_gon.find('y2').text
                x3 ,y3= ob_gon.find('x3').text,ob_gon.find('y3').text
                x4 ,y4= ob_gon.find('x4').text,ob_gon.find('y4').text
                # x5 ,y5= ob_gon.find('x5').text,ob_gon.find('y5').text
                polygon_box = [(x1, y1), (x2, y2), (x3, y3), (x4, y4)]
                dic = {"transcription": label, "points": polygon_box}
                data.append(dic)
            if (len(ob_gon) == 10):
                x1, y1 = ob_gon.find('x1').text, ob_gon.find('y1').text
                x2, y2 = ob_gon.find('x2').text, ob_gon.find('y2').text
                x3, y3 = ob_gon.find('x3').text, ob_gon.find('y3').text
                x4, y4 = ob_gon.find('x4').text, ob_gon.find('y4').text
                x5 ,y5= ob_gon.find('x5').text,ob_gon.find('y5').text
                polygon_box = [(x1, y1), (x2, y2), (x3, y3), (x4, y4),(x5, y5)]
                dic = {"transcription": label, "points": polygon_box}
                data.append(dic)
            if (len(ob_gon) == 12):
                x1, y1 = ob_gon.find('x1').text, ob_gon.find('y1').text
                x2, y2 = ob_gon.find('x2').text, ob_gon.find('y2').text
                x3, y3 = ob_gon.find('x3').text, ob_gon.find('y3').text
                x4, y4 = ob_gon.find('x4').text, ob_gon.find('y4').text
                x5 ,y5= ob_gon.find('x5').text,ob_gon.find('y5').text
                x6 ,y6= ob_gon.find('x6').text,ob_gon.find('y6').text
                polygon_box = [(x1, y1), (x2, y2), (x3, y3), (x4, y4),(x5, y5), (x6, y6)]
                dic = {"transcription": label, "points": polygon_box}
                data.append(dic)
            if (len(ob_gon) == 14):
                x1, y1 = ob_gon.find('x1').text, ob_gon.find('y1').text
                x2, y2 = ob_gon.find('x2').text, ob_gon.find('y2').text
                x3, y3 = ob_gon.find('x3').text, ob_gon.find('y3').text
                x4, y4 = ob_gon.find('x4').text, ob_gon.find('y4').text
                x5 ,y5= ob_gon.find('x5').text,ob_gon.find('y5').text
                x6 ,y6= ob_gon.find('x6').text,ob_gon.find('y6').text
                x7 ,y7= ob_gon.find('x7').text,ob_gon.find('y7').text
                polygon_box = [(x1, y1), (x2, y2), (x3, y3), (x4, y4),(x5, y5), (x6, y6), (x7, y7)]
                dic = {"transcription": label, "points": polygon_box}
                data.append(dic)
        # print(data)
        # print(json.dumps(data))
    line = fileName + "\t" + json.dumps(data) + '\n'
    print(line)
        # save_file.writelines(line)
    # with open(os.path.join(labeltxt, fp.split('.')[0] + '.txt'), 'a+') as f:
    #     f.write(sz)

这里由于存在多边形标注,我的数据集里总共有8、10、12、14四种,因此就简单的判断了不同的情况,最好保存到一个txt文件里。

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一休哥※

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值