yolov8训练CDLA数据文版版面分析

本文介绍了CDLACDLA中文文档版面分析数据集,涉及数据集结构、json转txt预处理步骤,并展示了如何使用ultralytics的yolov8模型进行训练和测试,包括模型加载、训练配置和图像检测应用。
摘要由CSDN通过智能技术生成

一.数据集介绍

CDlA数据集介绍:CDLA
CDLA是一个中文文档版面分析数据集,面向中文文献类(论文)场景。包含以下10个label:在这里插入图片描述
数据量:
共包含5000张训练集和1000张验证集,分别在train和val目录下。每张图片对应一个同名的标注文件(.json)。
数据展示:在这里插入图片描述
标注工具是labelme,所以标注格式和labelme格式一致。
数据结构:在这里插入图片描述
train和val里面分别存放图片及标注结果json文件
在这里插入图片描述

二. 数据预处理

将json文件转换成txt文件

import json 
import os 
import argparse
from tqdm import tqdm
import glob
import cv2 
import numpy as np
 
def convert_label_json(json_dir,save_dir,classes):
    files=os.listdir(json_dir)
    #删选出json文件
    jsonFiles=[]
    for file in files:
        if os.path.splitext(file)[1]==".json":
            jsonFiles.append(file)
    #获取类型        
    classes=classes.split(',')
    
    #获取json对应中对应元素
    for json_path in tqdm(jsonFiles):
        path=os.path.join(json_dir,json_path)
        with open(path,'r') as loadFile:
            print(loadFile)
            json_dict=json.load(loadFile)
        h,w=json_dict['imageHeight'],json_dict['imageWidth']
        txt_path=os.path.join(save_dir,json_path.replace('json','txt'))
        txt_file=open(txt_path,'w')
        
        for shape_dict in json_dict['shapes']:
            label=shape_dict['label'] 
            label_index=classes.index(label)
            points=shape_dict['points']
            points_nor_list=[]
            for point in points:
                points_nor_list.append(point[0]/w)
                points_nor_list.append(point[1]/h)
            points_nor_list=list(map(lambda x:str(x),points_nor_list))
            points_nor_str=' '.join(points_nor_list)
            label_str=str(label_index)+' '+points_nor_str+'\n'
            txt_file.writelines(label_str)
            
            
if __name__=="__main__":
    parser=argparse.ArgumentParser(description="json convert to txt params")
    #设json文件所在地址
    parser.add_argument('-json',type=str,default='cdla_data/label_data/val',help='json path')
    #设置txt文件保存地址
    parser.add_argument('-save',type=str,default='layout_analysis/cdla_data/val',help='save path')
    #设置label类型,用“,”分隔
    parser.add_argument('-classes',type=str,default='Header,Text,Reference,Figure caption,Figure,Table caption,Table,Title,Footer,Equation',help='classes')
    args=parser.parse_args()
    print(args.json,args.save,args.classes)
    convert_label_json(args.json,args.save,args.classes)

在这里插入图片描述

三.yoloV8模型环境搭建

采用ultralytics集成的代码进行训练:ultralytics

pip install ultralytics

yolov8预训练权重:yolov8预训练权重

四.模型训练

import sys
import os
sys.path.insert(0, os.path.dirname(os.getcwd()))
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from ultralytics import YOLO

def train_model():
    # 加载模型
    # model = YOLO("yolov8n.yaml")  # 从头开始构建新模型
    #print('model load。。。')
    model = YOLO("8npt/best.pt")  # 加载模型
    print('model load completed。。。')

    # 使用模型
    model.train(data="img-layout.yaml", epochs=300, device=1)# , lr0=0.0001)  # 训练模型
    metrics = model.val()  # 在验证集上评估模型性能
    print('metric : {}'.format(metrics))
    # results = model("https://ultralytics.com/images/bus.jpg")  # 对图像进行预测
    success = model.export(format="onnx")  # 将模型导出为 ONNX 格式

if __name__ == '__main__':
    train_model()

以上参数解释如下:

task:选择任务类型,可选[‘detect’, ‘segment’, ‘classify’, ‘init’]

mode: 选择是训练、验证还是预测的任务蕾西 可选[‘train’, ‘val’, ‘predict’]

model: 选择yolov8不同的模型配置文件,可选yolov8s.yaml、yolov8m.yaml、yolov8l.yaml、yolov8x.yaml

data: 选择生成的数据集配置文件

epochs:指的就是训练过程中整个数据集将被迭代多少次,显卡不行你就调小点。

batch:一次看完多少张图片才进行权重更新,梯度下降的mini-batch,显卡不行你就调小点。

五.模型测试

import os
import cv2
import sys

sys.path.insert(0, os.path.dirname(os.getcwd()))
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from ultralytics import YOLO
def infer(image_dir):
  """
  检测一个文件夹下的所有图片

  Args:
    image_dir: 图片文件夹路径

  Returns:
    None
  """

  model = YOLO('train2/weights/best.pt') #模型

  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
    results = model(image_path)
    print(results[0].plot())
    cv2.imwrite('test_result_v1/' + filename, results[0].plot()) #保存地址

if __name__ == '__main__':
  image_dir = 'test_data_20240304_pic/' #图片数据
  infer(image_dir)

测试效果:在这里插入图片描述
在这里插入图片描述

  • 11
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值