yolov8训练自己的关键点检测模型

本文详细介绍了如何使用YOLOv8训练自己的关键点检测模型,从标注数据集、数据转换、模型训练到效果测试,包括labelme的使用、json转yolo格式、训练集验证集划分、预训练模型下载、训练代码运行以及可能出现的错误和解决办法。
摘要由CSDN通过智能技术生成

参考:
官方教程:https://docs.ultralytics.com/zh/tasks/pose/

1、https://blog.csdn.net/weixin_38807927/article/details/135036450

2、https://blog.csdn.net/WYKB_Mr_Q/article/details/132035597
3、yolov8-pose关键点检测,从数据集制作到训练测试

下载最新yolov8代码

git clone https://github.com/ultralytics/ultralytics.git

一、标注数据集

安装labelme

pip install labelme -i https://pypi.tuna.tsinghua.edu.cn/simple

如果报错

$ labelme
2024-01-31 03:16:20,636 [INFO   ] __init__:get_config:67- Loading config file from: /home/diyun/.labelmerc
QObject::moveToThread: Current thread (0x56471fd8b1e0) is not the object's thread (0x564721420cb0).
Cannot move to target thread (0x56471fd8b1e0)

qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "/home/diyun/anaconda3/envs/pytorch_gpu/lib/python3.8/site-packages/cv2/qt/plugins" even though it was found.
This application failed to start because no Qt platform plugin could be initialized. Reinstalling the application may fix this problem.

Available platform plugins are: xcb, eglfs, linuxfb, minimal, minimalegl, offscreen, vnc, wayland-egl, wayland, wayland-xcomposite-egl, wayland-xcomposite-glx, webgl.


原因是pyqt5版本过高导致,指定版本:5.12.0

pip uninstall pyqt5
pip install pyqt5==5.12.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

先取消“保存图片数据”(减少标注文件大小);在文件下在这里插入图片描述

创建labels.txt文件,内存放物体类别名称与关键点名称

cd /media/diyun/T9/diyun/10_train_data/9_Chinese_chess/key_point/24_0125_train_184
touch label.txt
__ignore__
_background_
checkerboard
point_0
point_1
point_2
point_3

进入创建的labels.txt存在的文件夹下,输入以下命令打开labelme

labelme --labels labels.txt

鼠标点击右键,出现菜单栏,选择Create Retctangle,将需要检测的物体用矩形框框起来,然后给矩形框命名,点击弹出的框里面的命名,然后点击OK,Group ID可根据自己具体需求进行处理。

在这里插入图片描述
右键出现菜单栏,点击Create Point ,然后左键点击需要标注的位置并命名,随后点击OK。
在这里插入图片描述

二、分离img和json 以及json转yolo格式

2.1 json转yolo格式

labelme2yolo.py

# 将labelme标注的json文件转为yolo格式
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import glob
import json
import tqdm
# 物体类别
class_list = ["checkerboard"]
# 关键点的顺序
keypoint_list = ["point_0", "point_1", "point_2", "point_3"]

def json_to_yolo(img_data,json_data):
    h,w = img_data.shape[:2]
    # 步骤:
    # 1. 找出所有的矩形,记录下矩形的坐标,以及对应group_id
    # 2. 遍历所有的head和tail,记下点的坐标,以及对应group_id,加入到对应的矩形中
    # 3. 转为yolo格式

    rectangles = {
   }
    # 遍历初始化
    for shape in json_data["shapes"]:
        label = shape["label"] # pen, head, tail
        group_id = shape["group_id"] # 0, 1, 2, ...
        points = shape["points"] # x,y coordinates
        shape_type = shape["shape_type"]

        # 只处理矩形
        if shape_type == "rectangle":
            if group_id not in rectangles:
                rectangles[group_id] = {
   
                    "label": label,
                    "rect": points[0] + points[1],  # Rectangle [x1, y1, x2, y2]
                    "keypoints_list": []
                }
    # 遍历更新,将点加入对应group_id的矩形中
    for keypoint in keypoint_list:
        for shape in json_data["shapes"]:
            label = shape["label"]
            group_id = shape["group_id"]
            points = shape["points"]
            # 如果匹配到了对应的keypoint
            if label == keypoint:
                rectangles[group_id]["keypoints_list"].append(points[0])
    
    # 转为yolo格式
    yolo_list = []
    for id, rectangle in rectangles.items():
        result_list  = []
        label_id = class_list.index(rectangle["label"])
        # x1,y1,x2,y2
        x1,y1,x2,y2 = rectangle["rect"]
        # center_x, center_y, width, height
        center_x = (x1+x2)/2
        center_y = (y1+y2)/2
        width = abs(x1-x2)
        height = abs(y1-y2)
        # normalize
        center_x /= w
        center_y /= h
        width /= w
        height /= h

        # 保留6位小数
        center_x = round(center_x, 6)
        center_y = round(center_y, 6)
        width = round(width, 6)
        height = round(height, 6)

        # 添加 label_id, center_x, center_y, width, height
        result_list = [label_id, center_x, center_y, width, height]

        # 添加 p1_x, p1_y, p1_v, p2_x, p2_y, p2_v
        for point in rectangle["keypoints_list"]:
            x,y = point
            x,y = int(x), int(y)
            # normalize
            x /= w
            y /= h
            # 保留6位小数
            x = round
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

翟羽嚄

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值