从零开始使用Surya-OCR——项目源码拆解

目录

一、Surya模型检测使用Python接口中的源码详解

        1.选择模型检测GPU

        2.配置加载模型参数

        3.批量检测图片

         4.检测输出结果源码解读

        5.批量信息的保存和可视化


一、Surya模型检测使用Python接口中的源码详解

        使用surya源码进行模型检测的过程中, 模型的各种参数设置、环境变量配置都写在下载的源码下 /surya/settings.py 文件中,修改其中的参数即可实现全局的配置。

        1.选择模型检测GPU

        修改 settings.py 文件中的 TORCH_DEVICE 参数,默认为 None 时,运行监测代码会自动检查当前设备,并选择索引顺序最小的GPU运行——‘cuda:0’。在实际部署中,如果服务器中第一块GPU——‘cuda:0’有其他模型在跑,可能需要调整模型预测位置,将模型放到另一块GPU上运行。只需修改以下参数代码。

# 指定模型所在GPU
TORCH_DEVICE: Optional[str] = 'cuda:1'

        2.配置加载模型参数

        在 settings.py 文件中模型参数默认为在线加载,当服务器无法连接外部网络时,离线部署加载模型参数需调整设置中的地址。需根据自己模型存放位置,修改为下面的参数字符串。

## 地址需改为你存放模型的绝对地址
# 文本行检测模型
DETECTOR_MODEL_CHECKPOINT: str = "//Surya-OCR/hugging_model/surya_det2"

# 文本区域检测模型
LAYOUT_MODEL_CHECKPOINT: str = "//Surya-OCR/hugging_model/surya_layout"

        默认在线加载模型参数位置:

        参数修改内容:(为你存放离线下载模型地址)

        测试模型是否加载成功的代码如下:

from surya.model.detection.segformer import load_model, load_processor
from surya.settings import settings

# 行检测模型:surya_det_2
det_model = load_model()
det_processor =  load_processor()
print('det2_model load success')

# 区域检测:surya_layout
model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
print('layout_model load success')

        修改后,检验模型加载情况:

        3.批量检测图片

        实际部署中,官方文档提供 Python 接口检测的代码对单个图片检测顺利,但在批量检测图片集——文件夹时报错。 (单图代码写在上一篇博文《从零开始使用Surya-OCR——文本目标检测模型的安装与部署》中,具体可参考https://blog.csdn.net/qq_58718853/article/details/137150986

        第一个报错是,官方提供接口函数,无法读取文件夹内图片,报读取文件权限被拒。暂未实现直接解决该问题的办法。参看后续 batch_text_detection 源代码传参信息,得知图片读取后的传入函数结果是一个列表,可以选择替代方案实现同等效果,替代方案代码如下。

import os
from PIL import Image
from surya.detection import batch_text_detection
from surya.model.detection.segformer import load_model, load_processor

IMAGE_PATH = 'image_path'

model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
print('model load success')

# 批量将文件夹图片读入images列表中
images = []
for file in os.listdir(IMAGE_PATH):
    image_path = os.path.join(IMAGE_PATH, file)
    image = Image.open(image_path)
    images.append(image)

predictions = batch_text_detection(images, model, processor)
print(predictions)

        使用新代码后,原本的报错问题解决了,但出现了新的报错。

        第二个报错是,检查surya模型做批量预测任务时,在得到模型输出后,还会对多张图片的结果进行多进程处理。问题就在多进程处理时,源代码内置函数重复调用主函数——导致模型本来只需加载一次,此时不断加载模型致使GPU显存爆了。我们找到对应报错的源码处检查。

        按住 Ctrl 并点击报错的 batch_text_detection 函数,即可进入源码处,检查研究发现可能是Windows系统和 Linux 系统对于多进程和多线程的解释存在差异,本机是windows系统, 可能在导入 ProcessPoolExecutor 函数时,将我的主函数视为多进程对象,不断创建新的进程。但是此处希望实现的是中间过程结果的多线程处理,而不影响主函数。因此需要将 ProcessPoolExecutor 改为使用多线程的 ThreadPoolExecutor,问题即可解决。       

        经过上述所有源代码修改处理,成功运行主函数,得到surya模型批量检测图片后得到的框信息结果,并将其打印出来。

         4.检测输出结果源码解读

        Surya模型输出是自定义类的数据格式,下面根据其官方文档和项目源码解读其输出的格式,以方便后续对输出的处理,提取出所需的数据信息。

        官方文档:https://github.com/VikParuchuri/surya

        Surya 模型有三种预测模式——OCR & Text Line & Layout,对应三种模型输出的格式,每种模式的输出都是以类的形式定义的。下面重点放在 Text Line 文本行检测和 Layout 区域检测的源码信息解读上。   

 ①文本行检测的模型输出——Text Line

        将与输出相关的源代码从项目中单独提取出来看,下面是输出的基础类,即每个图片模型预测后的信息都封装在了 TextDetectionResult 里面。

"文本行检测"

# 输出的基础类
class TextDetectionResult(BaseModel):
    bboxes: List[PolygonBox]
    vertical_lines: List[ColumnLine]
    horizontal_lines: List[ColumnLine]
    heatmap: Any
    affinity_map: Any
    image_bbox: List[float]

        下面分别解释输出基础类中的具体信息都是怎么定义的,从源码中提出相关代码。

输出的第一个类信息:PolygonBox 注解

(下面非完整代码,为清晰类输出含义,只选取主要功能代码)

# 输出框信息类
class PolygonBox(BaseModel):
    polygon: List[List[float]]            ## 存储框四个角——全坐标
    confidence: Optional[float] = None    ## 框预测置信度

    def bbox(self) -> List[float]:
        box = [self.polygon[0][0], self.polygon[0][1], self.polygon[1][0], self.polygon[2][1]]
        if box[0] > box[2]:
            box[0], box[2] = box[2], box[0]
        if box[1] > box[3]:
            box[1], box[3] = box[3], box[1]
        return box                       ## 存储框左上右下——对角坐标

        通过源码可知,此类保存的是检测框的坐标信息和置信度,这是预测中的主要信息。使用具体的模型输出结果,可以更清楚该类的输出形状。输出TextDetectionResult 包含多个类信息,其中定义的 bboxes —— PolygonBox 存储的是一张图片内检测出来的所有框,而每个框的信息结构是包含三个子类:全坐标、置信度和对角坐标

输出的第二个类信息:ColumnLine 注解 

(下面非完整代码,为清晰类输出含义,只选取主要功能代码)

# 图片中线的检测
class ColumnLine(Bbox):
    vertical: bool      # 垂直线:有-True;无-False
    horizontal: bool    # 水平线:有-True;无-False

# 检测线的框坐标
class Bbox(BaseModel):
    bbox: List[float]

        同样通过源码可知,输出的第二个类信息 ColumnLine 是用来保存在模型预测中对图片检测到的水平线、垂直线的信息,如果检测到了那么就同时保存其框位置。通过具体模型输出对应位置检索,可以清晰理解。

输出的剩余类信息:heatmap、affinity_map、image_bbox

        剩余的类信息是对图片结果输出的补充,我们可以直接将其打印出来,看看其内容。最容易看出来的是 image_bbox ,其实际就是用一个框将整个图片框起来,然后返回对角坐标。而heatmap、affinity_map 则是 PIL.Image.Image 的类信息。

 ②文本区域检测的模型输出 ——Layout

          同文本行检测一样,废话少说,直接上代码和图。

# 区域检测输出
class LayoutResult(BaseModel):
    bboxes: List[LayoutBox]
    segmentation_map: Any
    image_bbox: List[float]、

class LayoutBox(PolygonBox):
    label: str             ## 多了一个区域类别的预测

 

        5.批量信息的保存和可视化

        将surya模型自定义的类输出转化为 json 格式保存到指定文件夹,代码如下。

import os
import json
from PIL import Image
from surya.detection import batch_text_detection
from surya.model.detection.segformer import load_model, load_processor
import cv2
import numpy as np

IMAGE_PATH = 'iamge_path'   ## 检测图片保存地址
json_file = 'json_path'     ## 框json保存地址
checkpoint = 'model_path'   ## 模型参数加载地址
heat_file = 'heat_path'     ## 热图保存

########### 上述为修改部分,根据实际地址填入,下面无需修改   ############

model, processor = load_model(checkpoint=checkpoint), load_processor(checkpoint=checkpoint)
print('model load success')

# 模型预测
images = []
image_name = []
for file in os.listdir(IMAGE_PATH):
    image_path = os.path.join(IMAGE_PATH, file)
    image = Image.open(image_path)
    images.append(image)
    image_name.append(file)

predictions = batch_text_detection(images, model, processor)
print('predict success')

# 保存模型结果
## 类型转为json
def class_to_json(bboxes, file, box_type=True):
    json_list = []
    for i, bbox in enumerate(bboxes):
        if box_type:
            json_dict = dict()
            box = bbox.bbox
            box.append(bbox.confidence)
            json_dict["id"] = i
            json_dict["name"] = file
            json_dict["box"] = box
            json_list.append(json_dict)
        else:
            json_dict = dict()
            box = bbox.bbox
            json_dict["id"] = i
            json_dict["name"] = file
            json_dict["box"] = box
            json_list.append(json_dict)
    return json_list

## 保存到指定文件夹
def save_json(json_list, json_path):
    with open(json_path, 'w') as f:
        json.dump(json_list, f)

## 主函数
def save_predict(predictions, image_name, heat_file):
    for i, pred in enumerate(predictions):
        # 框信息保存
        bboxes = pred.bboxes
        vertical = pred.vertical_lines
        horizontal = pred.horizontal_lines
        file = image_name[i]
        bboxes_json = class_to_json(bboxes, file)
        vertical_json = class_to_json(vertical, file, box_type=False)
        horizontal_json = class_to_json(horizontal, file, box_type=False)

        basename = file.split('.')[0]
        save_json(bboxes_json, os.path.join(json_file+'box/', basename + '.json'))
        save_json(vertical_json, os.path.join(json_file + 'vertical/', basename + '.json'))
        save_json(horizontal_json, os.path.join(json_file + 'horizontal/', basename + '.json'))

        # 热图调参信息保存
        heatmap = pred.heatmap
        img = cv2.cvtColor(np.asarray(heatmap), cv2.COLOR_RGB2BGR)
        cv2.imwrite(heat_file+basename+'.jpg', img)
        print(basename + '  success')

if __name__ == '__main__':
    save_predict(predictions, image_name, heat_file)

        可视化框的代码如下。

import os
import json
import cv2

# jpg、json、vis文件位置
jpg_path = 'JPG'
json_path = 'JSON'
vis_path = 'VIS'

########### 上述为修改部分,根据实际地址填入,下面无需修改   ############

# 可视化锚框
## 锚框展示细节
def hsv2bgr(h, s, v):
    h_i = int(h * 6)
    f = h * 6 - h_i
    p = v * (1 - s)
    q = v * (1 - f * s)
    t = v * (1 - (1 - f) * s)
    r, g, b = 0, 0, 0
    if h_i == 0:
        r, g, b = v, t, p
    elif h_i == 1:
        r, g, b = q, v, p
    elif h_i == 2:
        r, g, b = p, v, t
    elif h_i == 3:
        r, g, b = p, q, v
    elif h_i == 4:
        r, g, b = t, p, v
    elif h_i == 5:
        r, g, b = v, p, q
    return int(b * 255), int(g * 255), int(r * 255)

def random_color(id):
    h_plane = (((id << 2) ^ 0x937151) % 100) / 100.0
    s_plane = (((id << 3) ^ 0x315793) % 100) / 100.0
    return hsv2bgr(h_plane, s_plane, 1)

# 可视化主函数
def visualize(json_path, jpg_path, vis_path, box_type=True):
    if box_type:
        for file in os.listdir(json_path):
            with open(json_path+file,'r') as f:
                drawResult = json.load(f)
            basefile = file.split('.')[0]
            jpg_file = os.path.join(jpg_path,basefile+".jpg")
            img = cv2.imread(jpg_file)
            for idx, result in enumerate(drawResult):
                left, top, right, bottom = int(result['box'][0]), int(result['box'][1]), int(result['box'][2]), int(result['box'][3])
                label = int(result['box'][4])
                color = random_color(1)
                cv2.rectangle(img, (left, top), (right, bottom), color=color ,thickness=2, lineType=cv2.LINE_AA)
                caption = f"{'ZW'}"
                w, h = cv2.getTextSize(caption, 0, 1, 2)[0]
                cv2.rectangle(img, (left - 3, top - 33), (left + w + 10, top), color, -1)
                cv2.putText(img, caption, (left, top - 5), 0, 1, (0, 0, 0), 2, 16)
            save_file = os.path.join(vis_path, basefile+".jpg")
            print(save_file)
            cv2.imwrite(save_file, img)
    else:
        for file in os.listdir(json_path):
            with open(json_path+file,'r') as f:
                drawResult = json.load(f)
            basefile = file.split('.')[0]
            jpg_file = os.path.join(jpg_path,basefile+".jpg")
            img = cv2.imread(jpg_file)
            for idx, result in enumerate(drawResult):
                left, top, right, bottom = int(result['box'][0]), int(result['box'][1]), int(result['box'][2]), int(result['box'][3])
                color = random_color(1)
                cv2.rectangle(img, (left, top), (right, bottom), color=color ,thickness=2, lineType=cv2.LINE_AA)
                caption = f"{'ZW'}"
                w, h = cv2.getTextSize(caption, 0, 1, 2)[0]
                cv2.rectangle(img, (left - 3, top - 33), (left + w + 10, top), color, -1)
                cv2.putText(img, caption, (left, top - 5), 0, 1, (0, 0, 0), 2, 16)
            save_file = os.path.join(vis_path, basefile+".jpg")
            print(save_file)
            cv2.imwrite(save_file, img)


if __name__ == '__main__':
   visualize(json_path+'box/', jpg_path, vis_path+'box/')
   visualize(json_path + 'vertical/', jpg_path, vis_path + 'vertical/', box_type=False)
   visualize(json_path + 'horizontal/', jpg_path, vis_path + 'horizontal/', box_type=False)

  • 13
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值