yolov8制作并训练自己的关键点检测模型pose模型(纯小白也能会)

因为最近有一个关键点检测的视觉任务,想要训练一个自己的关键点检测模型,但是按照现有博客去训练总是报错,所有想要自己写一篇训练心得和教程,希望对大家有帮助。

yolo代码准备

YOLOv8官方代码地址:

GitHub - ultralytics/ultralytics: NEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLiteNEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLite - ultralytics/ultralyticsicon-default.png?t=N7T8https://github.com/ultralytics/ultralytics

训练数据准备

安装labelme(数据集标注软件)

pip install labelme

启动labelme

(因为我的labelme没有添加到环境变量中,所以我启动的时候加上路径)大家正常下载完成之后直接在命令行中输入labelme即可启动后如下图所示

打开我们要标记的数据集图片

开始标记

创建矩形框标记要识别的物体

创建控制点,注意要和前面的标记框的id匹配,标记完成之后的标签显示如下

处理labelme得到的json文件

将json数据集转化为coco数据集格式

import os
import sys
import glob
import json
import argparse
import numpy as np
from tqdm import tqdm
from labelme import utils


class Labelme2coco_keypoints():
    def __init__(self, args):
        """
        Lableme 关键点数据集转 COCO 数据集的构造函数:

        Args
            args:命令行输入的参数
                - class_name 根类名字

        """

        self.classname_to_id = {args.class_name: 1}
        self.images = []
        self.annotations = []
        self.categories = []
        self.ann_id = 0
        self.img_id = 0

    def save_coco_json(self, instance, save_path):
        json.dump(instance, open(save_path, 'w', encoding='utf-8'), ensure_ascii=False, indent=1)

    def read_jsonfile(self, path):
        with open(path, "r", encoding='utf-8') as f:
            return json.load(f)

    def _get_box(self, points):
        min_x = min_y = np.inf
        max_x = max_y = 0
        for x, y in points:
            min_x = min(min_x, x)
            min_y = min(min_y, y)
            max_x = max(max_x, x)
            max_y = max(max_y, y)
        return [min_x, min_y, max_x - min_x, max_y - min_y]

    def _get_keypoints(self, points, keypoints, num_keypoints):
        """
        解析 labelme 的原始数据, 生成 coco 标注的 关键点对象

        例如:
            "keypoints": [
                67.06149888292556,  # x 的值
                122.5043507571318,  # y 的值
                1,                  # 相当于 Z 值,如果是2D关键点 0:不可见 1:表示可见。
                82.42582269256718,
                109.95672933232304,
                1,
                ...,
            ],

        """

        if points[0] == 0 and points[1] == 0:
            visable = 0
        else:
            visable = 1
            num_keypoints += 1
        keypoints.extend([points[0], points[1], visable])
        return keypoints, num_keypoints

    def _image(self, obj, path):
        """
        解析 labelme 的 obj 对象,生成 coco 的 image 对象

        生成包括:id,file_name,height,width 4个属性

        示例:
             {
                "file_name": "training/rgb/00031426.jpg",
                "height": 224,
                "width": 224,
                "id": 31426
            }

        """

        image = {}

        img_x = utils.img_b64_to_arr(obj['imageData'])  # 获得原始 labelme 标签的 imageData 属性,并通过 labelme 的工具方法转成 array
        image['height'], image['width'] = img_x.shape[:-1]  # 获得图片的宽高

        # self.img_id = int(os.path.basename(path).split(".json")[0])
        self.img_id = self.img_id + 1
        image['id'] = self.img_id

        image['file_name'] = os.path.basename(path).replace(".json", ".jpg")

        return image

    def _annotation(self, bboxes_list, keypoints_list, json_path):
        """
        生成coco标注

        Args:
            bboxes_list: 矩形标注框
            keypoints_list: 关键点
            json_path:json文件路径

        """

        if len(keypoints_list) != args.join_num * len(bboxes_list):
            print('you loss {} keypoint(s) with file {}'.format(args.join_num * len(bboxes_list) - len(keypoints_list), json_path))
            print('Please check !!!')
            sys.exit()
        i = 0
        for object in bboxes_list:
            annotation = {}
            keypoints = []
            num_keypoints = 0

            label = object['label']
            bbox = object['points']
            annotation['id'] = self.ann_id
            annotation['image_id'] = self.img_id
            annotation['category_id'] = int(self.classname_to_id[label])
            annotation['iscrowd'] = 0
            annotation['area'] = 1.0
            annotation['segmentation'] = [np.asarray(bbox).flatten().tolist()]
            annotation['bbox'] = self._get_box(bbox)

            for keypoint in keypoints_list[i * args.join_num: (i + 1) * args.join_num]:
                point = keypoint['points']
                annotation['keypoints'], num_keypoints = self._get_keypoints(point[0], keypoints, num_keypoints)
            annotation['num_keypoints'] = num_keypoints

            i += 1
            self.ann_id += 1
            self.annotations.append(annotation)

    def _init_categories(self):
        """
        初始化 COCO 的 标注类别

        例如:
        "categories": [
            {
                "supercategory": "hand",
                "id": 1,
                "name": "hand",
                "keypoints": [
                    "wrist",
                    "thumb1",
                    "thumb2",
                    ...,
                ],
                "skeleton": [
                ]
            }
        ]
        """

        for name, id in self.classname_to_id.items():
            category = {}

            category['supercategory'] = name
            category['id'] = id
            category['name'] = name
            # 17 个关键点数据
            category['keypoint'] = [str(i + 1) for i in range(args.join_num)]

            self.categories.append(category)

    def to_coco(self, json_path_list):
        """
        Labelme 原始标签转换成 coco 数据集格式,生成的包括标签和图像

        Args:
            json_path_list:原始数据集的目录

        """

        self._init_categories()

        for json_path in tqdm(json_path_list):
            obj = self.read_jsonfile(json_path)  # 解析一个标注文件
            self.images.append(self._image(obj, json_path))  # 解析图片
            shapes = obj['shapes']  # 读取 labelme shape 标注

            bboxes_list, keypoints_list = [], []
            for shape in shapes:
                if shape['shape_type'] == 'rectangle':  # bboxs
                    bboxes_list.append(shape)           # keypoints
                elif shape['shape_type'] == 'point':
                    keypoints_list.append(shape)

            self._annotation(bboxes_list, keypoints_list, json_path)

        keypoints = {}
        keypoints['info'] = {'description': 'Lableme Dataset', 'version': 1.0, 'year': 2021}
        keypoints['license'] = ['BUAA']
        keypoints['images'] = self.images
        keypoints['annotations'] = self.annotations
        keypoints['categories'] = self.categories
        return keypoints


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--class_name", default="a", help="class name", type=str)
    parser.add_argument("--input",  default="./json", help="json file path (labelme)", type=str)
    parser.add_argument("--output", default="./coco", help="output file path (coco format)", type=str)
    parser.add_argument("--join_num",  default=2, help="number of join", type=int)
    # parser.add_argument("--ratio", help="train and test split ratio", type=float, default=0.5)
    args = parser.parse_args()

    labelme_path = args.input
    saved_coco_path = args.output

    json_list_path = glob.glob(labelme_path + "/*.json")

    print('{} for json files'.format(len(json_list_path)))
    print('Start transform please wait ...')

    l2c_json = Labelme2coco_keypoints(args)  # 构造数据集生成类

    # 生成coco类型数据
    keypoints = l2c_json.to_coco(json_list_path)
    l2c_json.save_coco_json(keypoints, os.path.join(saved_coco_path, "keypoints.json"))

coco数据集格式转txt

# COCO 格式的数据集转化为 YOLO 格式的数据集
# --json_path 输入的json文件路径
# --save_path 保存的文件夹名字,默认为当前目录下的labels。

import os
import json
from tqdm import tqdm
import argparse

parser = argparse.ArgumentParser()
# 这里根据自己的json文件位置,换成自己的就行
parser.add_argument('--json_path',
                    default='coco/cat.json', type=str,
                    help="input: coco format(json)")
# 这里设置.txt文件保存位置
parser.add_argument('--save_path', default='txt', type=str,
                    help="specify where to save the output dir of labels")
arg = parser.parse_args()


def convert(size, box):
    dw = 1. / (size[0])
    dh = 1. / (size[1])
    x = box[0] + box[2] / 2.0
    y = box[1] + box[3] / 2.0
    w = box[2]
    h = box[3]

    x = round(x * dw, 6)
    w = round(w * dw, 6)
    y = round(y * dh, 6)
    h = round(h * dh, 6)
    return (x, y, w, h)


if __name__ == '__main__':
    json_file = arg.json_path  # COCO Object Instance 类型的标注
    ana_txt_save_path = arg.save_path  # 保存的路径

    data = json.load(open(json_file, 'r'))
    if not os.path.exists(ana_txt_save_path):
        os.makedirs(ana_txt_save_path)

    id_map = {}  # coco数据集的id不连续!重新映射一下再输出!
    with open(os.path.join(ana_txt_save_path, 'classes.txt'), 'w') as f:
        # 写入classes.txt
        for i, category in enumerate(data['categories']):
            f.write(category['name']+"\n")
            id_map[category['id']] = i
    # print(id_map)
    # 这里需要根据自己的需要,更改写入图像相对路径的文件位置。
    # list_file = open(os.path.join(ana_txt_save_path, 'train2017.txt'), 'w')
    for img in tqdm(data['images']):
        filename = img["file_name"]
        img_width = img["width"]
        img_height = img["height"]
        img_id = img["id"]
        head, tail = os.path.splitext(filename)
        ana_txt_name = head + ".txt"  # 对应的txt名字,与jpg一致
        f_txt = open(os.path.join(ana_txt_save_path, ana_txt_name), 'w')
        for ann in data['annotations']:
            if ann['image_id'] == img_id:
                box = convert((img_width, img_height), ann["bbox"])
                f_txt.write("%s %s %s %s %s" % (id_map[ann["category_id"]], box[0], box[1], box[2], box[3]))
                counter=0
                for i in range(len(ann["keypoints"])):
                    if ann["keypoints"][i] == 2 or ann["keypoints"][i] == 1 or ann["keypoints"][i] == 0:
                        f_txt.write(" %s " % format(ann["keypoints"][i] + 1,'6f'))
                        counter=0
                    else:
                        if counter==0:
                            f_txt.write(" %s " % round((ann["keypoints"][i] / img_width),6))
                        else:
                            f_txt.write(" %s " % round((ann["keypoints"][i] / img_height),6))
                        counter+=1
        f_txt.write("\n")
        f_txt.close()

运行完成后我们会得到我们训练时要用到的txt文件,文件如下所示

0 0.440456 0.48896 0.769801 0.889601 0.084046  0.326211  2.000000  0.624217  0.074786  2.000000  0.38604  0.62963  2.000000  0.595157  0.545584  2.000000  0.537607  0.745726  2.000000  0.509117  0.875356  2.000000  0.61453  0.844017  2.000000 

每个数的含义从左到右依次为:

物体类别,物体标记框的中心点坐标x,y标记框的宽和高w,h,后面每三个数据表示一个关键点的坐标和属性(x,y,属性)

关键点属性介绍:0:关键点没有显露出,不可见;1:关键点被遮挡;2:关键点可见

训练数据集分类

代码:

import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join
import random
from shutil import copyfile

# 根据自己的数据标签修改
classes = ["cat"]


def clear_hidden_files(path):
    dir_list = os.listdir(path)
    for i in dir_list:
        abspath = os.path.join(os.path.abspath(path), i)
        if os.path.isfile(abspath):
            if i.startswith("._"):
                os.remove(abspath)
        else:
            clear_hidden_files(abspath)


def convert(size, box):
    dw = 1. / size[0]
    dh = 1. / size[1]
    x = (box[0] + box[1]) / 2.0
    y = (box[2] + box[3]) / 2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)


# def convert_annotation(image_id):
#     in_file = open('VOCdevkit/VOC2007/Annotations/%s.xml' %image_id)
#     out_file = open('VOCdevkit/VOC2007/YOLOLabels/%s.txt' %image_id, 'w')
#     tree=ET.parse(in_file)
#     root = tree.getroot()
#     size = root.find('size')
#     w = int(size.find('width').text)
#     h = int(size.find('height').text)
#
#     for obj in root.iter('object'):
#         difficult = obj.find('difficult').text
#         cls = obj.find('name').text
#         if cls not in classes or int(difficult) == 1:
#             continue
#         cls_id = classes.index(cls)
#         xmlbox = obj.find('bndbox')
#         b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
#         bb = convert((w,h), b)
#         out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
#     in_file.close()
#     out_file.close()

wd = os.getcwd()
wd = os.getcwd()
data_base_dir = os.path.join(wd, "VOCdevkit/")
if not os.path.isdir(data_base_dir):
    os.mkdir(data_base_dir)
work_sapce_dir = os.path.join(data_base_dir, "VOC2007/")
if not os.path.isdir(work_sapce_dir):
    os.mkdir(work_sapce_dir)
annotation_dir = os.path.join(work_sapce_dir, "Annotations/")
if not os.path.isdir(annotation_dir):
    os.mkdir(annotation_dir)
clear_hidden_files(annotation_dir)
image_dir = os.path.join(work_sapce_dir, "JPEGImages/")
if not os.path.isdir(image_dir):
    os.mkdir(image_dir)
clear_hidden_files(image_dir)
yolo_labels_dir = os.path.join(work_sapce_dir, "YOLOLabels/")
if not os.path.isdir(yolo_labels_dir):
    os.mkdir(yolo_labels_dir)
clear_hidden_files(yolo_labels_dir)
yolov5_images_dir = os.path.join(data_base_dir, "images/")
if not os.path.isdir(yolov5_images_dir):
    os.mkdir(yolov5_images_dir)
clear_hidden_files(yolov5_images_dir)
yolov5_labels_dir = os.path.join(data_base_dir, "labels/")
if not os.path.isdir(yolov5_labels_dir):
    os.mkdir(yolov5_labels_dir)
clear_hidden_files(yolov5_labels_dir)
yolov5_images_train_dir = os.path.join(yolov5_images_dir, "train/")
if not os.path.isdir(yolov5_images_train_dir):
    os.mkdir(yolov5_images_train_dir)
clear_hidden_files(yolov5_images_train_dir)
yolov5_images_test_dir = os.path.join(yolov5_images_dir, "val/")
if not os.path.isdir(yolov5_images_test_dir):
    os.mkdir(yolov5_images_test_dir)
clear_hidden_files(yolov5_images_test_dir)
yolov5_labels_train_dir = os.path.join(yolov5_labels_dir, "train/")
if not os.path.isdir(yolov5_labels_train_dir):
    os.mkdir(yolov5_labels_train_dir)
clear_hidden_files(yolov5_labels_train_dir)
yolov5_labels_test_dir = os.path.join(yolov5_labels_dir, "val/")
if not os.path.isdir(yolov5_labels_test_dir):
    os.mkdir(yolov5_labels_test_dir)
clear_hidden_files(yolov5_labels_test_dir)

train_file = open(os.path.join(wd, "yolov5_train.txt"), 'w')
test_file = open(os.path.join(wd, "yolov5_val.txt"), 'w')
train_file.close()
test_file.close()
train_file = open(os.path.join(wd, "yolov5_train.txt"), 'a')
test_file = open(os.path.join(wd, "yolov5_val.txt"), 'a')
list_imgs = os.listdir(image_dir)  # list image files
probo = random.randint(1, 100)
print(list_imgs)
print("Probobility: %d" % probo)
for i in range(0, len(list_imgs)):
    path = os.path.join(image_dir, list_imgs[i])
    if os.path.isfile(path):
        image_path = image_dir + list_imgs[i]
        voc_path = list_imgs[i]
        (nameWithoutExtention, extention) = os.path.splitext(os.path.basename(image_path))
        (voc_nameWithoutExtention, voc_extention) = os.path.splitext(os.path.basename(voc_path))
        annotation_name = nameWithoutExtention + '.xml'
        annotation_path = os.path.join(annotation_dir, annotation_name)
        label_name = nameWithoutExtention + '.txt'
        label_path = os.path.join(yolo_labels_dir, label_name)
    probo = random.randint(1, 100)
    print("Probobility: %d" % probo)
    if (probo < 80):  # train dataset
        # if os.path.exists(annotation_path):
            train_file.write(image_path + '\n')
            # convert_annotation(nameWithoutExtention) # convert label
            copyfile(image_path, yolov5_images_train_dir + voc_path)
            copyfile(label_path, yolov5_labels_train_dir + label_name)
    else:  # test dataset
        # if os.path.exists(annotation_path):
            test_file.write(image_path + '\n')
            # convert_annotation(nameWithoutExtention) # convert label
            copyfile(image_path, yolov5_images_test_dir + voc_path)
            copyfile(label_path, yolov5_labels_test_dir + label_name)
train_file.close()
test_file.close()

上面这个结构中,图片和txt文件的名称一定要保证能对上,一个图片就对应有一个对应的相同文件名的txt文件。

运行之后会得到

这个文件目录结构的数据集,至此我们的数据集准备就结束了。

训练配置文件的准备

首先要为自己的训练数据集制作一个mydata.yaml如下

train: C:\Users\1\Desktop\lzy_yolov8\ultralytics\data\images\train
#test: ./mydata/test/image
val: C:\Users\1\Desktop\lzy_yolov8\ultralytics\data\images\val
#如果按照上述方式建立文件夹,则上面train、test和val地址可以不变


# Keypoints
kpt_shape: [7, 3]  # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
flip_idx: [0, 1, 2, 3, 4, 5 , 6]

# 修改你的检测类别
names:
  0: cat

放到yolo原码中ultralytics\ultralytics\cfg\datasets\mydata.yaml这个路径下

修改yolo原码中ultralytics\ultralytics\cfg\default.yaml

至此配置文件就准备好了

开始训练

在终端中输入命令

 yolo pose cfg= ./cfg/default.yaml

等待训练完成会得到一个best.pt模型文件和训练过程中的数据图。

模型验证

# 测试图片
from ultralytics import YOLO
import cv2
import numpy as np
import sys

# 读取命令行参数
weight_path = './best.pt'
media_path = "./b.png"

# 加载模型
model = YOLO(weight_path)

# 获取类别
objs_labels = model.names  # get class labels
print(objs_labels)

# 类别的颜色
class_color = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255),
               (255, 255, 0), (255, 0, 0), (0, 255, 0)]
# 关键点的顺序
class_list = ["cat"]

# 关键点的颜色
keypoint_color = [(255, 0, 0), (0, 255, 0), (255, 0, 0), (0, 255, 0), (255, 0, 0), (0, 255, 0), (255, 0, 0)]

# 读取图片
frame = cv2.imread(media_path)
frame = cv2.resize(frame, (frame.shape[1] // 2, frame.shape[0] // 2))
# rotate
# 检测
result = list(model(frame, conf=0.5, stream=True))[0]  # inference,如果stream=False,返回的是一个列表,如果stream=True,返回的是一个生成器
boxes = result.boxes  # Boxes object for bbox outputs
boxes = boxes.cpu().numpy()  # convert to numpy array

# 遍历每个框
for box in boxes.data:
    l, t, r, b = box[:4].astype(np.int32)  # left, top, right, bottom
    conf, id = box[4:]  # confidence, class
    id = int(id)
    # 绘制框
    cv2.rectangle(frame, (l, t), (r, b), (0, 0, 255), 2)
    # 绘制类别+置信度(格式:98.1%)
    cv2.putText(frame, f"{objs_labels[id]} {conf * 100:.1f}", (l, t - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                (0, 0, 255), 1)

# 遍历keypoints
keypoints = result.keypoints  # Keypoints object for pose outputs
keypoints = keypoints.cpu().numpy()  # convert to numpy array

# draw keypoints, set first keypoint is red, second is blue
for keypoint in keypoints.data:
    for i in range(len(keypoint)):
        x, y, _ = keypoint[i]
        x, y = int(x), int(y)
        cv2.circle(frame, (x, y), 3, (0, 255, 0), -1)
        # cv2.putText(frame, f"{keypoint_list[i]}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, keypoint_color[i], 2)

    if len(keypoint) >= 2:
        # draw arrow line from tail to half between head and tail
        x0, y0, _ = keypoint[0]
        x1, y1, _ = keypoint[1]
        x2, y2, _ = keypoint[2]
        x3, y3, _ = keypoint[3]
        x4, y4, _ = keypoint[4]
        x5, y5, _ = keypoint[5]
        x6, y6, _ = keypoint[6]

        cv2.line(frame, (int(x0), int(y0)), (int(x1), int(y1)), (255, 0, 255), 1)
        cv2.line(frame, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 255), 1)
        cv2.line(frame, (int(x2), int(y2)), (int(x3), int(y3)), (255, 0, 255), 1)
        cv2.line(frame, (int(x3), int(y3)), (int(x4), int(y4)), (255, 0, 255), 1)
        cv2.line(frame, (int(x4), int(y4)), (int(x5), int(y5)), (255, 0, 255), 1)
        cv2.line(frame, (int(x5), int(y5)), (int(x6), int(y6)), (255, 0, 255), 1)

        # center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
    # cv2.arrowedLine(frame, (int(x2), int(y2)), (int(center_x), int(center_y)), (255, 0, 255), 4,
    #                line_type=cv2.LINE_AA, tipLength=0.1)

# save image
cv2.imwrite("result.jpg", frame)
print("save result.jpg")
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值