如何利用模型自动给人体数据集打标签

手动给数据集打标签的方法在这里就不介绍了,那么如何利用模型自动给人体数据集打标签呢?
最近阅读了DWPose的代码,并测试了其demo目录下的topdown_demo_with_mmdet脚本。该脚本一次只能推理一张图片或一个视频,并可以将关键点数据导出到json文件。
相关运行指令请参考MMPose文档中的demo脚本部分,只需将pose_config配置文件和pose_checkpoint模型权重替换为最新版本即可。我用的是rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py和dw-ll_ucoco_384.pth。
经过测试,原脚本可以正常运行,但只能做单张推理。因此我修改了代码,使其可以进行批量处理,并按原目录结构分别保存可视化图片和标签文件,同时剔除了json文件中无用的部分。

import mimetypes
import os
import time
from argparse import ArgumentParser
from tqdm import tqdm
import cv2
import json_tricks as json
import mmcv
import mmengine
import numpy as np

from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples, split_instances
from mmpose.utils import adapt_mmdet_pipeline
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="json_tricks")

try:
    from mmdet.apis import inference_detector, init_detector
    has_mmdet = True
except (ImportError, ModuleNotFoundError):
    has_mmdet = False



def process_one_image(args, img_path, detector, pose_estimator, visualizer=None):
    """Detect boxes in one image, estimate keypoints and optionally draw them.

    Args:
        args: Parsed command-line namespace. Reads ``det_cat_id``,
            ``bbox_thr`` and ``nms_thr`` to filter detections. When a
            visualizer is given it also reads ``draw_heatmap``,
            ``draw_bbox``, ``show_kpt_idx``, ``skeleton_style``, ``show``,
            ``kpt_thr`` and, if present, ``show_interval``.
        img_path: Path of the image to process.
        detector: Detection model handed to ``inference_detector``.
        pose_estimator: Pose model handed to ``inference_topdown``.
        visualizer: Optional visualizer. When given, the merged prediction
            is drawn through ``visualizer.add_datasample``.

    Returns:
        The ``pred_instances`` entry of the merged pose results, or ``None``
        when the merged sample has no such entry.
    """
    # predict bbox
    det_result = inference_detector(detector, img_path)
    pred_instance = det_result.pred_instances.cpu().numpy()
    # Stack the score next to each box so ``nms`` receives boxes + score.
    bboxes = np.concatenate(
        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
    # Keep only the requested category above the score threshold.
    keep = np.logical_and(pred_instance.labels == args.det_cat_id,
                          pred_instance.scores > args.bbox_thr)
    bboxes = bboxes[keep]
    # After NMS drop the score column again: only the 4 box values are used.
    bboxes = bboxes[nms(bboxes, args.nms_thr), :4]

    # predict keypoints
    pose_results = inference_topdown(pose_estimator, img_path, bboxes)
    data_samples = merge_data_samples(pose_results)

    if visualizer is not None:
        # The decoded image is only needed for drawing, so load it lazily.
        # NOTE(review): the image is read without an explicit channel order
        # and process_images writes the drawn result back unchanged; confirm
        # which channel order the visualizer expects.
        img = mmcv.imread(img_path)
        visualizer.add_datasample(
            'result',
            img,
            data_sample=data_samples,
            draw_gt=False,
            draw_heatmap=args.draw_heatmap,
            draw_bbox=args.draw_bbox,
            show_kpt_idx=args.show_kpt_idx,
            skeleton_style=args.skeleton_style,
            show=args.show,
            # Forward --show-interval; it was parsed but never used before.
            # getattr keeps callers working whose namespace lacks the option.
            wait_time=getattr(args, 'show_interval', 0),
            kpt_thr=args.kpt_thr
        )

    return data_samples.get('pred_instances', None)


def _collect_image_paths(img_folder, subfolder_ids, hand_types):
    """Return image paths under ``img_folder/<id>/<hand_type>``, sorted per folder."""
    image_exts = ('.png', '.jpg', '.jpeg')
    image_paths = []
    for subfolder_id in subfolder_ids:
        for hand_type in hand_types:
            hand_folder = os.path.join(img_folder, str(subfolder_id), hand_type)
            # Not every numbered folder has to exist; skip the missing ones
            # instead of letting os.listdir raise FileNotFoundError.
            if not os.path.isdir(hand_folder):
                continue
            image_paths.extend(
                os.path.join(hand_folder, name)
                # Sorted for a reproducible processing order.
                for name in sorted(os.listdir(hand_folder))
                # Compare in lower case so e.g. ".JPG" is picked up as well.
                if name.lower().endswith(image_exts))
    return image_paths


def process_images(args, img_folder, detector, pose_estimator, visualizer=None,
                   subfolder_ids=range(1, 100), hand_types=('aaa', 'bbb')):
    """Run ``process_one_image`` over a folder tree and save the results.

    Images are looked up in ``img_folder/<subfolder_id>/<hand_type>/``. For
    each image the predictions are optionally written to a JSON file and the
    visualized image is optionally written to disk, both mirroring the
    image's path relative to ``img_folder``.

    Args:
        args: Parsed command-line namespace. Reads ``save_predictions``,
            ``json_output_root`` and ``img_output_root`` here and is passed
            on to ``process_one_image``.
        img_folder: Root folder of the dataset.
        detector: Detection model, forwarded to ``process_one_image``.
        pose_estimator: Pose model, forwarded to ``process_one_image``.
        visualizer: Optional visualizer, forwarded to ``process_one_image``.
            Visualized images are only written when it is given.
        subfolder_ids: Names of the first-level folders (converted with
            ``str``). Defaults to ``1`` .. ``99``.
        hand_types: Names of the second-level folders inside each subfolder.

    Returns:
        A list with one ``{"img_path": ..., "instances": ...}`` dict per
        processed image; empty when ``args.save_predictions`` is false or no
        image was found.
    """
    image_paths = _collect_image_paths(img_folder, subfolder_ids, hand_types)
    if not image_paths:
        print(f'No images found under {img_folder!r}; nothing to process.')
        return []

    all_pred_instances = []

    for img_path in tqdm(image_paths, desc="Processing images"):
        pred_instances = process_one_image(args, img_path, detector, pose_estimator, visualizer)

        # Outputs mirror the input layout relative to the dataset root.
        relative_path = os.path.relpath(img_path, img_folder)

        if args.save_predictions:
            single_result = {
                "img_path": img_path,
                "instances": split_instances(pred_instances)
            }
            all_pred_instances.append(single_result)

            # Save results to individual JSON files
            json_filename = os.path.splitext(relative_path)[0] + ".json"
            json_filepath = os.path.join(args.json_output_root, json_filename)

            # Create the JSON output directory if it does not exist yet.
            os.makedirs(os.path.dirname(json_filepath), exist_ok=True)
            with open(json_filepath, 'w', encoding='utf-8') as f:
                json.dump(single_result, f, indent='\t')

        # Without a visualizer there is no drawn image to save.
        if args.img_output_root and visualizer is not None:
            img_vis = visualizer.get_image()
            img_output_path = os.path.join(args.img_output_root, relative_path)

            # Create the image output directory if it does not exist yet.
            os.makedirs(os.path.dirname(img_output_path), exist_ok=True)
            mmcv.imwrite(img_vis, img_output_path)

    return all_pred_instances



def main():
    """Batch-label a folder of images with a detector and a pose estimator.

    Parses the command line, builds the mmdet detector and the pose
    estimator, builds a visualizer when images are to be saved or shown, and
    hands everything to ``process_images``.

    Raises:
        ImportError: If mmdet could not be imported.
        SystemExit: Via ``parser.error`` when ``--img-folder`` is empty or is
            not an existing directory.
    """
    parser = ArgumentParser()
    parser.add_argument('det_config', help='Config file for detection')
    parser.add_argument('det_checkpoint', help='Checkpoint file for detection')
    parser.add_argument('pose_config', help='Config file for pose')
    parser.add_argument('pose_checkpoint', help='Checkpoint file for pose')
    parser.add_argument('--img-folder', type=str, default='images', help='Folder containing multiple images for processing')
    # Kept for command-line compatibility; not read by this function.
    parser.add_argument('--input', type=str, default='', help='Image/Video file')
    parser.add_argument('--img-output-root', type=str, default='img_results', help='Directory to save the visualized images.')
    parser.add_argument('--json-output-root', type=str, default='json_results', help='Directory to save the JSON results.')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='whether to show img')
    parser.add_argument(
        '--save-predictions',
        action='store_true',
        default=True,
        help='whether to save predicted results')
    # --save-predictions defaults to True, so on its own it can never be
    # switched off; this companion flag makes saving optional.
    parser.add_argument(
        '--no-save-predictions',
        dest='save_predictions',
        action='store_false',
        help='Do not save predicted results')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--det-cat-id',
        type=int,
        default=0,
        help='Category id for bounding box detection model')
    parser.add_argument(
        '--bbox-thr',
        type=float,
        default=0.3,
        help='Bounding box score threshold')
    parser.add_argument(
        '--nms-thr',
        type=float,
        default=0.3,
        help='IoU threshold for bounding box NMS')
    parser.add_argument(
        '--kpt-thr',
        type=float,
        default=0.3,
        help='Visualizing keypoint thresholds')
    parser.add_argument(
        '--draw-heatmap',
        action='store_true',
        default=False,
        help='Draw heatmap predicted by the model')
    parser.add_argument(
        '--show-kpt-idx',
        action='store_true',
        default=False,
        help='Whether to show the index of keypoints')
    parser.add_argument(
        '--skeleton-style',
        default='mmpose',
        type=str,
        choices=['mmpose', 'openpose'],
        help='Skeleton style selection')
    parser.add_argument(
        '--radius',
        type=int,
        default=3,
        help='Keypoint radius for visualization')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization')
    parser.add_argument(
        '--show-interval', type=int, default=0, help='Sleep seconds per frame')
    parser.add_argument(
        '--alpha', type=float, default=0.8, help='The transparency of bboxes')
    parser.add_argument(
        '--draw-bbox', action='store_true', help='Draw bboxes of instances')

    # Explicit errors instead of ``assert``, which is stripped under -O.
    if not has_mmdet:
        raise ImportError('Please install mmdet to run the demo.')
    args = parser.parse_args()
    if not args.img_folder:
        parser.error('Please specify the img-folder argument.')
    # Fail before the (slow) model loading if the dataset root is wrong.
    if not os.path.isdir(args.img_folder):
        parser.error(f'--img-folder {args.img_folder!r} is not an existing directory.')

    # build detector
    detector = init_detector(
        args.det_config, args.det_checkpoint, device=args.device)
    detector.cfg = adapt_mmdet_pipeline(detector.cfg)

    # build pose estimator
    pose_estimator = init_pose_estimator(
        args.pose_config,
        args.pose_checkpoint,
        device=args.device,
        cfg_options=dict(
            model=dict(test_cfg=dict(output_heatmaps=args.draw_heatmap))))

    # build visualizer (only needed when images are saved or shown)
    visualizer = None
    if args.img_output_root or args.show:
        pose_estimator.cfg.visualizer.radius = args.radius
        pose_estimator.cfg.visualizer.alpha = args.alpha
        pose_estimator.cfg.visualizer.line_width = args.thickness
        visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
        visualizer.set_dataset_meta(
            pose_estimator.dataset_meta, skeleton_style=args.skeleton_style)

    process_images(args, args.img_folder, detector, pose_estimator, visualizer)


# Run the batch-labelling pipeline only when the file is executed as a script.
if __name__ == '__main__':
    main()

经过测试,代码可以正常运行。当然,自己使用时需要修改相关路径以及遍历文件夹的那部分代码,因为这部分是按照我自身的需求写的,其他地方不用修改。导出的json按照coco-wholebody格式输出,包含133个关键点的坐标及置信度。输出的图片是已标记好关键点的可视化结果,可以不输出,输出只是为了检查自动打标签的质量。也可以把输出的json文件映射回对应的图片,检查关键点以及bounding box的效果,有时间我再更新一下check脚本。

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

没了海绵宝宝的派大星

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值