在线体验:https://huggingface.co/spaces/opendatalab/DocLayout-YOLO
论文地址:https://arxiv.org/abs/2410.12628
代码地址:https://github.com/opendatalab/DocLayout-YOLO
一. 安装
由于我只想把 DocLayout-YOLO 当作工具使用，因此直接用 pip 安装。另外，下文代码默认从当前目录加载模型权重文件 doclayout_yolo_docstructbench_imgsz1024.pt，需要提前下载该文件（下载方式见代码仓库说明）：
pip install doclayout-yolo
二. 代码封装
import os
import cv2
import torch
from doclayout_yolo import YOLOv10
class DocLayoutYOLO(YOLOv10):
    """Thin wrapper around ``YOLOv10`` for DocLayout-YOLO layout detection.

    Loads a checkpoint, installs a readable class-name map on the inner model,
    selects an inference device (CUDA, then Apple MPS, then CPU), and exposes a
    ``predict`` that returns plain Python lists instead of ``Results`` objects.

    NOTE(review): overriding ``predict`` with a different signature and return
    type may not match how the base class forwards ``__call__`` to ``predict``;
    prefer calling ``predict`` explicitly rather than calling the instance.
    """

    def __init__(self, model_path="./doclayout_yolo_docstructbench_imgsz1024.pt"):
        """Load the weights at *model_path* and choose a device.

        Args:
            model_path: Path to the DocLayout-YOLO ``.pt`` weights file.
        """
        super().__init__(model_path)
        # Human-readable labels for class indices 0-9; presumably matches the
        # class order of the DocStructBench checkpoint used by default — confirm
        # if a different checkpoint is loaded.
        self.model.names = {0: 'title', 1: 'plain text', 2: 'abandon', 3: 'figure', 4: 'figure_caption',
                            5: 'table', 6: 'table_caption', 7: 'table_footnote', 8: 'isolate_formula',
                            9: 'formula_caption'}
        # Prefer CUDA, fall back to Apple-silicon MPS, and finally to the CPU.
        # The chosen device is passed explicitly to the base predict() below.
        if torch.cuda.is_available():
            device = 'cuda'
        elif torch.backends.mps.is_available():
            device = 'mps'
        else:
            device = 'cpu'
        self.model.device = device
        print(f"Using device: {self.model.device}")

    def predict(self, image_path, imgsz=1024, conf=0.5):
        """Run layout detection and return the first result as plain lists.

        Args:
            image_path: Source forwarded unchanged to ``YOLOv10.predict``
                (e.g. an image file path).
            imgsz: Inference image size forwarded to the base predictor.
            conf: Confidence threshold forwarded to the base predictor.

        Returns:
            ``(boxes, class_names, scores, orig_img)``:
            ``boxes`` is a list of ``[x1, y1, x2, y2]`` cast to int (fractional
            part dropped), ``class_names`` is the label of each box,
            ``scores`` is the confidence of each box, and ``orig_img`` is the
            first result's ``orig_img``. When the base predictor returns no
            results, ``([], [], [], None)`` is returned.
        """
        results = super().predict(source=image_path, imgsz=imgsz, conf=conf, device=self.model.device)
        if not results:
            # Plain empty lists: the same values tensor.tolist() yields for zero
            # detections, so callers see a consistent shape either way.
            return [], [], [], None

        # Only the first result is used; a single image source is expected.
        result = results[0]
        detections = result.boxes
        class_names = [self.model.names[int(label)] for label in detections.cls]
        return (
            detections.xyxy.int().tolist(),
            class_names,
            detections.conf.tolist(),
            result.orig_img,
        )
def main():
    """Smoke-test the wrapper on ``./test.jpeg`` and print the detections.

    Prints, one per line and in this order: the box coordinates, the class
    name of each box, and the confidence of each box.
    """
    detector = DocLayoutYOLO()
    coords, labels, scores, _ = detector.predict("./test.jpeg")
    # Coordinates, class names, confidences — same order as the return tuple.
    for output in (coords, labels, scores):
        print(output)


if __name__ == "__main__":
    main()
三. 效果