hrnet训练的pt模型结合目标检测进行关键点识别的更准确前向推理

本篇在将图像输入hrnet识别之前先进行目标检测来确定识别的位置,让识别更加精准。

本段代码设置了一个区域框BOX,让人走入区域内才开始检测,适用于考核等场景,也可以直接去掉BOX也是一样的效果。若画面背景中有多个行人,还是只取要检测的那个人,同理还是适用考核场景。
为了让检测效果更直观,在一些点位直接使用线连接起来模拟人体骨骼。

import os
import sys
import numpy as np
from mmpose.apis import init_model, inference_topdown
import cv2

import torch
sys.path.append("/home/yons/train/code/yolov5")
from models.experimental import attempt_load
from utils.general import non_max_suppression, scale_boxes
from torchvision import transforms

# 配置文件路径和检查点文件路径
config_file = '/home/.../pose_td-hm_hrnet-w48_8xb32-210e_PullUp-det/pose_td-hm_hrnet-w48_8xb32-210e_PullUp-det.py'
checkpoint_file = '/home/.../pose_td-hm_hrnet-w48_8xb32-210e_PullUp-det/best_coco_AP_epoch_250.pth'

# 初始化姿态估计模型
pose_model = init_model(config_file, checkpoint_file, device='cuda:0')
yolo_model = attempt_load('yolov5x6.pt', device='cuda:0')  # 加载训练好的yolov5模型
pose_model.eval()
yolo_model.eval()

VIDEO_PATH = 'input.mp4' 
BOX = (300, 50, 300, 450)  # 区域框的左上角坐标和宽高
OUTPUT_VIDEO_PATH = 'output.mp4'

def draw_keypoints(frame, keypoints, box, det_box):
    # 在帧上绘制关键点
    # 这里假设关键点是一个 Nx2 的数组,其中 N 是关键点的数量
    # 并且关键点的坐标是相对于裁剪区域的
    x, y, w, h = box
    for kp in keypoints:
        kp_x, kp_y = kp
        x_rec1 = int(det_box[0] + x)
        y_rec1 = int(det_box[1] + y)
        x_rec2 = int(det_box[2] + x)
        y_rec2 = int(det_box[3] + y)
        cv2.rectangle(frame, (x_rec1, y_rec1), (x_rec2, y_rec2), (0, 0, 255), 2)
        x_cir = int(kp_x + det_box[0] + x)
        y_cir = int(kp_y + det_box[1] + y)
        cv2.circle(frame, (x_cir, y_cir), 3, (0, 255, 0), -1)
    lines = [(15, 13), (13, 11), (16, 14), (14, 12), (11, 12), (5, 11), (6, 12), (5, 6),
             (5, 7), (6, 8), (7, 9), (8, 10), (1, 2), (0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6)]
    for line in lines:
        pt1 = (int(keypoints[line[0]][0] + det_box[0] + x), int(keypoints[line[0]][1] + det_box[1] + y))
        pt2 = (int(keypoints[line[1]][0] + det_box[0] + x), int(keypoints[line[1]][1] + det_box[1] + y))
        cv2.line(frame, pt1, pt2, (0, 255, 0), 2)

def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    """Resizes and pads image to new_shape with stride-multiple constraints, returns resized image, ratio, padding."""
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
        print(im.shape)

    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))

    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border

    return im, ratio, (dw, dh)

# 处理每一张图像
k = 0
if __name__ == '__main__':
    # 打开视频
    cap = cv2.VideoCapture(VIDEO_PATH)

    # 获取视频的一些属性
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # 创建 VideoWriter 对象
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 或者使用 'XVID'
    out = cv2.VideoWriter(OUTPUT_VIDEO_PATH, fourcc, fps, (width, height))

    while True:
        # 读取一帧
        ret, frame = cap.read()
        if not ret:
            break

        # 加载帧
        x, y, w, h = BOX
        img0 = frame[y:y+h, x:x+w, :]

        img_size = (1280, 1280)
        stride = max(int(yolo_model.stride.max()), 32)
        img = letterbox(img0, img_size, stride=stride, auto=True)[0]   # padded resize
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)  # contiguous

        # yolo_model.warmup(imgsz=(1, 3, *img_size))  # warmup
        img = torch.from_numpy(img).to('cuda:0')
        # img = img.half() if yolo_model.fp16 else img.float()  # uint8 to fp16/32
        img = img.float()
        img /= 255  # 0 - 255 to 0.0 - 1.0
        if len(img.shape) == 3:
            img = img[None]  # expand for batch dim

        with torch.no_grad():
            pred = yolo_model(img)  # Inference
        pred = non_max_suppression(pred, 0.25, 0.45, 0)  # NMS    0 for person
        # input()

        det_box = None
        # Process predictions
        # print(pred)
        for i, det in enumerate(pred):  # per image
            # print(det)
            if len(det):
                # Rescale boxes from img_size to img0 size
                det[:, :4] = scale_boxes(img.shape[2:], det[:, :4], img0.shape).round()

                max_bb = None
                for x1, y1, x2, y2, conf, cls in reversed(det):
                    if max_bb is None:
                        max_bb = [x1, y1, x2, y2]
                    else:
                        if ((x2 - x1) * (y2 - y1)) > (
                                (max_bb[2] - max_bb[0]) * (max_bb[3] - max_bb[3])):
                            max_bb = [x1, y1, x2, y2]
                det_box = max_bb
                for idx in range(len(det_box)):
                    det_box[idx] = int(det_box[idx])

        x1, y1, x2, y2 = det_box
        # print(det_box)
        img_seg = img0[y1:y2, x1:x2, :]

        person_results = np.array([[0, 0, x2-x1, y2-y1]])
        # 推理得到关键点坐标
        pose_results = inference_topdown(pose_model, img_seg, person_results, bbox_format='xyxy')

        # 提取关键点坐标并检查是否检测出17个关键点
        keypoints = []
        if len(pose_results) > 0 and pose_results[0].pred_instances.keypoints.shape[1] == 17:
            keypoints = pose_results[0].pred_instances.keypoints[0]

        draw_keypoints(frame, keypoints, BOX, det_box)

        # 写入帧
        out.write(frame)

        # 显示帧
        cv2.imshow('frame', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    out.release()
    cv2.destroyAllWindows()

input.mp4视频如下:

引体向上原始视频


output.mp4视频如下:

引体向上推理结果视频

大概原理是区域框内进行一系列处理后输入进yolo进行目标检测,在多个目标框内选出我们要检测的人物的目标框输入进hrnet得到关键点,关键点从目标框映射回区域框再映射回原图,得到最终结果。

训练自己的数据集进行HRNet关键点检测,可以按照以下步骤进行操作: 1. 准备数据集:首先,需要准备自己的数据集,包括图像和对应的关键点标注。可以使用现有的数据集或者自己创建一个新的数据集。 2. 数据预处理:对数据集进行预处理,包括图像的缩放、裁剪、归一化等操作,以及关键点的坐标转换等。可以参考HRNet源码中的数据预处理部分,根据具体需求进行相应的处理。 3. 修改配置文件:在HRNet源码中,可以找到相应的配置文件,例如`experiments/pose/coco/hrnet/w32_256x192_adam_lr1e-3.yaml`。可以根据自己的数据集和训练需求修改配置文件中的相关参数,比如数据集路径、训练epoch数、学习率等。 4. 训练模型:使用修改后的配置文件进行模型训练。可以运行HRNet训练脚本,例如`tools/train.py`,并指定修改后的配置文件作为参数进行训练。 5. 模型评估与调优:训练完成后,可以使用自己的数据集进行模型评估,比如计算关键点的精度、平均准确度等指标。根据评估结果,可以进行模型调优,如调整网络结构、增加训练数据量、调整超参数等。 6. 导出模型:最后,可以导出训练好的模型,以便在实际应用中使用。可以使用HRNet提供的导出模型的脚本,例如`tools/valid.py`,并指定训练好的模型路径进行导出。 通过以上步骤,就可以使用HRNet对自己的数据集进行关键点检测训练,并得到相应的模型。请注意,具体的操作细节可能会根据实际情况有所不同,请参考相关文档和源码进行具体操作。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* *3* [关键点检测一:HRNet数据预处理(MPII)](https://blog.csdn.net/qq_43312130/article/details/122034420)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 100%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值