深度学习——COCO全身关键点提取部分指定的关键点

使用yolov8训练人体关键点模型;
一个模型多个类别,不同类别关键点个数不一致;
我目前了解到的好像只有COCO是有全身关键点;
COCO全身关键点链接:https://github.com/jin-s13/COCO-WholeBody
在这里插入图片描述
以下代码能从COCO全身标注的json中提取出来想要的关键点和对应的类别;并且直接转换成了yolov8可用的训练txt格式,
注意:其中最后一行图片搬运使用的os.link,类似Linux中的硬链接,并非复制,如果内存充足的情况下可以使用shutil.copy替换;

20240408-测试版本代码

# -*- coding: UTF-8 -*-
"""
@Project :ultralytics 
@IDE     :PyCharm 
@Author  :沐枫
@Date    :2024/4/8 15:11

单线程处理,
因为技术不精,想改写成多线程发现代码速度并未提高所以就没写多线程版本代码;

COCO-WholeBody下载链接:https://github.com/jin-s13/COCO-WholeBody
"""
import os
import json
import shutil
from typing import Dict, List

from tqdm import tqdm
import cv2

COCO_URL_ROOT = "http://images.cocodataset.org"


class DecodeWholeBodyImage:
    """
    解析图片的字典信息
    """

    def __init__(self, image_info: Dict):
        self.license = image_info['license']
        self.date_captured = image_info['date_captured']
        self.flickr_url = image_info['flickr_url']

        self.id = image_info['id']
        self.image_id = image_info['id']  # 和annotation中的image_id一样,对应到一起可以找到对应的目标
        self.file_name = image_info['file_name']

        # 'http://images.cocodataset.org/val2017/000000397133.jpg'
        self.coco_url = image_info['coco_url']
        self.height = image_info['height']
        self.width = image_info['width']

        if 'http' not in self.flickr_url:
            self.url = self.coco_url
        else:
            self.url = self.flickr_url


class DecodeWholeBodyAnnotation:
    """
    一个目标的信息解析
    边界框格式是ltwh
    """

    def __init__(self, annotation: Dict):
        # 通过这个id找图片
        self.image_id = annotation['image_id']
        # 是否是人群,0:不是
        self.iscrowd = annotation['iscrowd']
        # 分割
        self.segmentation = annotation['segmentation']
        # 目标的id
        self.id = annotation['id']
        # 目标的类别索引
        self.category_id = annotation['category_id']

        # 身体关键点和box
        self.body_points = annotation['keypoints']
        self.body_box = annotation['bbox']
        self.num_keypoints = annotation['num_keypoints']  # 关键点有效个数

        # 脚关键点
        self.foot_points = annotation['foot_kpts']
        self.foot_valid = annotation['foot_valid']  # 脚关键点的有效性

        # 脸的关键点和box
        self.face_points = annotation['face_kpts']
        self.face_box = annotation['face_box']
        self.face_valid = annotation['face_valid']  # 有效性

        # left手关键点和box
        self.lefthand_box = annotation['lefthand_box']
        self.lefthand_points = annotation['lefthand_kpts']
        self.lefthand_valid = annotation['lefthand_valid']  # 有效性

        # right关键点和box
        self.righthand_box = annotation['righthand_box']
        self.righthand_points = annotation['righthand_kpts']
        self.righthand_valid = annotation['righthand_valid']  # 有效性

        # 把所有的关键点整合到一起
        self.all_points = list()
        self.all_points.extend(self.body_points)
        self.all_points.extend(self.foot_points)
        self.all_points.extend(self.face_points)
        self.all_points.extend(self.lefthand_points)
        self.all_points.extend(self.righthand_points)


def clip(value, min_v, max_v):
    if value < min_v:
        value = min_v

    if value > max_v:
        value = max_v

    return value


def ltwh2xywhn(bbox, img_h, img_w):
    """
    输入是COCO格式的box是ltwh,输出是归一化之后的xywhn,可以利用来训练yolo模型
    Args:
        bbox: ltwh
        img_h:
        img_w:

    Returns:

    """
    x1, y1, w, h = bbox  # ltwh

    x1 = clip(x1, 0, img_w)
    y1 = clip(y1, 0, img_h)
    x2 = clip(x1 + w, 0, img_w)
    y2 = clip(y1 + h, 0, img_h)

    w = x2 - x1
    h = y2 - y1

    # 计算box中心点坐标
    x = x1 + w / 2
    y = y1 + h / 2

    # 归一化
    x = x / img_w
    y = y / img_h
    w = w / img_w
    h = h / img_h

    return x, y, w, h


def get_point(point_index, all_points, img_shape_wh=None, max_point_num=0):
    """
    根据关键点索引从关键点list中找到对应的关键点并进行归一化后转成字符串格式,返回回去
    Args:
        point_index: 想要的关键点的索引
        all_points: 所有关键点的list
        img_shape_wh: (w, h),入股哦是None,就不归一化
        max_point_num: 关键点最多的个数

    Returns: str

    """
    current_point_num = len(point_index)
    # 保存结果的字符串
    res = ""
    if current_point_num > 0:
        # 先根据索引获取到想要的关键点
        for index in point_index:
            start = index * 3
            end = (index + 1) * 3

            x, y, v = all_points[start:end]
            # 对可视信息调整
            if 0 < v <= 1:
                v = 1
            if 1 < v <= 2:
                v = 2

            # 是否归一化
            if img_shape_wh is not None:
                img_w, img_h = img_shape_wh
                x = clip(x, 0, img_w) / img_w
                y = clip(y, 0, img_h) / img_h

            res += f"{x:.6f} {y:.6f} {int(v)} "

        # 如果关键点比较少,就使用全0填充
        if current_point_num < max_point_num:
            _temp = " ".join((["0"] * (max_point_num - current_point_num) * 3))
            res += _temp

    else:  # 没有指定关键点索引,使用全0代替
        _temp = " ".join((["0"] * MAX_POINT_NUM * 3))
        res += _temp

    return res.strip()


if __name__ == '__main__':
    data_root = r"Z:\Datasets\Detection\COCO2017"
    if data_root == "":
        raise ValueError(f"{data_root} should not be empty string")
    data_root = os.path.abspath(data_root)

    # 项目名称
    project = "FallAndSit"
    # 规定想保留的目标
    # cls_index指的是类别索引
    # box_type指的是该类别的边界框类型,
    # body_box指的是人体的边界框;face_box指的是人脸边界框;lefthand_box指的是左手边界框;righthand_box指的是右手边界框
    # point_index指的是该类别的关键点索引,整体的索引,会按照顺序取关键点
    BOX_TYPE = ("body_box", "face_box", "lefthand_box", "righthand_box",)
    POINT_INDEX_MAX = 129
    Object_info: List[Dict] = [
        {
            "cls_index": 0,  # 指定该目标的类别索引
            "box_type": "body_box",  # 指定该目标使用哪个框
            "point_index": (6, 5, 12, 11, 14, 13, 16, 15),  # 指定关键点的索引
        },

        {"cls_index": 1,
         "box_type": "face_box",
         "point_index": (2, 1, 4, 3, 71, 77, 53, 31)},

        # {"cls_index": 1,
        #  "box_type": "face_box",
        #  "point_index": tuple()},
    ]
    # 关键点最多的数量,用来对齐关键点的数量,如果不够的使用[0, 0, 0]填充
    MAX_POINT_NUM = 0
    for value in Object_info:
        MAX_POINT_NUM = max(MAX_POINT_NUM, len(value["point_index"]))

    if len(Object_info) == 0:
        raise ValueError("Object_dict is empty")

    image_root = os.path.join(data_root, project, "images")
    txt_root = os.path.join(data_root, project, "labels")

    if os.path.exists(image_root):
        shutil.rmtree(image_root)
    os.makedirs(image_root)
    if os.path.exists(txt_root):
        shutil.rmtree(txt_root)
    os.makedirs(txt_root)

    json_path_list = [
        os.path.join(data_root, "annotations", "coco-wholebody", "coco_wholebody_val_v1.0.json"),
        # os.path.join(data_root, "annotations", "coco-wholebody", "coco_wholebody_train_v1.0.json"),
    ]

    for json_path in json_path_list:
        # 保存数据
        information = dict()

        print(f"read {json_path}")
        # 读文件
        with open(json_path, 'r', encoding="utf-8") as rFile:
            json_data = json.load(rFile)
        print(f"read {json_path} finish ...")

        # 先处理图片
        print(f"deal images ...")
        # list:[dict ...]
        image_list = json_data['images']

        for step in tqdm(range(len(image_list)), desc=f"deal {os.path.basename(json_path)}"):
            # 下面这些可以写成一个函数,使用多线程处理
            img_info = DecodeWholeBodyImage(image_list[step])

            # 图片路径img_info.coco_url:'http://images.cocodataset.org/val2017/000000397133.jpg'
            # 原图路径
            img_path = os.path.join(data_root,
                                    img_info.coco_url.replace(COCO_URL_ROOT, "images").replace("/", os.sep))

            img = cv2.imread(img_path)
            if img is None:
                continue
            h, w = img.shape[:2]

            dst_img_path = img_path.replace(os.path.join(data_root, "images"), image_root)
            information[img_info.id] = {
                "file_name": img_info.file_name,  # 图片名称
                'h': h,  # 图片的高
                'w': w,  # 图片的宽
                "src_path": img_path,  # 原图路径
                "dst_path": dst_img_path,  # 该项目中目标路径
            }

        print("deal image information finish ...")
        # 收集好图片的信息之后,开始收集目标的信息
        print("deal annotation ...")

        annotations = json_data['annotations']
        for step in tqdm(range(len(annotations)), desc=f"deal {os.path.basename(json_path)}"):
            # 解析目标
            annotation = DecodeWholeBodyAnnotation(annotations[step])

            # 获取目标对应的图片的信息
            image_info = information[annotation.image_id]
            # 图片名
            file_name = image_info["file_name"]
            # 后缀
            _, suffix = os.path.splitext(file_name)
            # 原图路径
            src_image_path = image_info["src_path"]
            # 目标图路径
            dst_image_path = image_info["dst_path"]
            # 标签保存路径
            txt_path = dst_image_path.replace(image_root, txt_root).replace(suffix, ".txt")

            # 图片的宽高
            img_h = image_info['h']
            img_w = image_info['w']

            # 开始获取想要的关键点和目标
            results = list()
            for value in Object_info:
                cls_index = value["cls_index"]
                box_type = value["box_type"]
                assert box_type in BOX_TYPE, f"{box_type} not in {BOX_TYPE}"

                # 目标字符串
                res = ""
                if box_type == "body_box" and (not annotation.iscrowd):  # 不是人群,大密集的
                    box = ltwh2xywhn(annotation.body_box, img_h=img_h, img_w=img_w)
                    res += f"{cls_index} {box[0]:.6f} {box[1]:.6f} {box[2]:.6f} {box[3]:.6f} "

                    # 关键点的索引tuple
                    point_index = value["point_index"]
                    # 关键点字符串
                    res += get_point(point_index=point_index,
                                     all_points=annotation.all_points,
                                     img_shape_wh=(img_w, img_h),
                                     max_point_num=MAX_POINT_NUM
                                     )

                elif box_type == "face_box" and annotation.face_valid:
                    box = ltwh2xywhn(annotation.face_box, img_h=img_h, img_w=img_w)
                    res += f"{cls_index} {box[0]:.6f} {box[1]:.6f} {box[2]:.6f} {box[3]:.6f} "

                    # 关键点的索引tuple
                    point_index = value["point_index"]
                    # 关键点字符串
                    res += get_point(point_index=point_index,
                                     all_points=annotation.all_points,
                                     img_shape_wh=(img_w, img_h),
                                     max_point_num=MAX_POINT_NUM
                                     )

                elif box_type == "lefthand_box" and annotation.lefthand_valid:
                    box = ltwh2xywhn(annotation.lefthand_box, img_h=img_h, img_w=img_w)
                    res += f"{cls_index} {box[0]:.6f} {box[1]:.6f} {box[2]:.6f} {box[3]:.6f} "

                    # 关键点的索引tuple
                    point_index = value["point_index"]
                    # 关键点字符串
                    res += get_point(point_index=point_index,
                                     all_points=annotation.all_points,
                                     img_shape_wh=(img_w, img_h),
                                     max_point_num=MAX_POINT_NUM
                                     )

                elif box_type == "righthand_box" and annotation.lefthand_valid:
                    box = ltwh2xywhn(annotation.righthand_box, img_h=img_h, img_w=img_w)
                    res += f"{cls_index} {box[0]:.6f} {box[1]:.6f} {box[2]:.6f} {box[3]:.6f} "

                    # 关键点的索引tuple
                    point_index = value["point_index"]
                    # 关键点字符串
                    res += get_point(point_index=point_index,
                                     all_points=annotation.all_points,
                                     img_shape_wh=(img_w, img_h),
                                     max_point_num=MAX_POINT_NUM,
                                     )

                # 如果当前
                if res != "":
                    results.append(res)

            os.makedirs(os.path.dirname(txt_path), exist_ok=True)
            with open(txt_path, "a", encoding="utf-8") as wFile:
                for line in results:
                    wFile.write(f"{line}\n")

            # 映射图片
            if not os.path.exists(dst_image_path):
                os.makedirs(os.path.dirname(dst_image_path), exist_ok=True)
                # 图片使用硬链接
                os.link(src_image_path, dst_image_path)
                # 直接复制
                # shutil.copy(src_image_path, dst_image_path)


示例:【人脸7个关键点,身体8个关键点】
在这里插入图片描述
在这里插入图片描述

可视化代码

# -*- coding: UTF-8 -*-
"""
@Project :ultralytics 
@IDE     :PyCharm 
@Author  :沐枫
@Date    :2024/3/21 16:24 
"""
import os
import cv2
import numpy as np

# 数据集根目录
data_root = r"coco2017"
data_root = os.path.abspath(data_root)

image_root = os.path.join(data_root, "images", "val2017")
txt_root = os.path.join(data_root, "labels", "val2017")

count = 0
for root, _, files in os.walk(txt_root):
    for file in files:
        if count >= 100:
            break
        image_name, suffix = os.path.splitext(file)

        txt_path = os.path.join(root, file)
        image_path = txt_path.replace(txt_root, image_root).replace(suffix, ".jpg")

        image = cv2.imread(image_path)
        labels = np.loadtxt(txt_path)
        if labels.ndim < 2:
            labels = labels[None, ...]
        if len(labels) == 0:
            continue

        img_h, img_w = image.shape[:2]

        bboxes = labels[..., 1:5] * [img_w, img_h, img_w, img_h, ]
        # NOTE:因为只有8个关键点
        kpt_num = len(labels[0][5:]) // 3
        if len(labels[0][5:]) % 3 != 0:
            # 就算是这个目标没有指定关键点,使用代码COCOWholeBodyPoints.py生成的txt应该全是0,不应该没有数据
            raise ValueError("len(labels[..., 5:]) should equal kpt_num * 3, "
                             "len(labels[..., 5:]) % 3 remainder should be 0.")

        kpts = labels[..., 5:].reshape(-1, kpt_num, 3) * [img_w, img_h, 1]

        for box in np.array(bboxes, dtype=np.int32):
            x, y, w, h = box
            x1 = x - w // 2
            y1 = y - h // 2
            x2 = x1 + w
            y2 = y1 + h
            cv2.rectangle(image, pt1=(x1, y1), pt2=(x2, y2), color=(0, 255, 255), thickness=1)

        for kpt in np.array(kpts, dtype=np.int32):
            for i, (x, y, v) in enumerate(kpt):
                if v == 0:
                    continue
                cv2.circle(image, center=(int(x), int(y)), radius=5, color=(255, 0, 255), thickness=-1, )
                cv2.putText(image, text=f"{i}", org=(int(x) + 6, int(y) + 6), color=(255, 0, 255),
                            fontFace=1, fontScale=1.5, thickness=2)

        save_path = f"vis/{image_name}.jpg"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        cv2.imwrite(save_path, image)
        print(image_path)
        count += 1

  • 14
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值