手动给数据集打标签的方法在这里就不介绍了,那么如何利用模型自动给人体数据集打标签呢?
最近阅读了 DWPose 的代码，并测试了其 demo 目录下的 topdown_demo_with_mmdet 脚本。该脚本一次只能推理一张图片或一个视频，并可将关键点数据导出为 json 文件。
相关运行命令请参考 MMPose 文档中 demo 脚本的部分。只需将 pose_config 配置文件和 pose_checkpoint 模型权重替换为最新版本即可，我用的是 rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py 和 dw-ll_ucoco_384.pth。
经测试可以正常运行，但只能逐张推理。因此我对代码做了修改，使其支持批量处理，按原目录结构分别保存可视化图片和标签文件，同时剔除了 json 文件中的无用部分。
import mimetypes
import os
import time
from argparse import ArgumentParser
from tqdm import tqdm
import cv2
import json_tricks as json
import mmcv
import mmengine
import numpy as np
from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples, split_instances
from mmpose.utils import adapt_mmdet_pipeline
import warnings
# Suppress UserWarnings raised from inside the json_tricks module
# (presumably emitted while serialising the prediction dicts).
warnings.filterwarnings("ignore", category=UserWarning, module="json_tricks")

# mmdet is optional at import time; main() refuses to run without it.
try:
    from mmdet.apis import inference_detector, init_detector
    has_mmdet = True
except ImportError:  # also covers ModuleNotFoundError (a subclass)
    has_mmdet = False
def process_one_image(args, img_path, detector, pose_estimator, visualizer=None):
    """Detect people in one image and estimate their keypoints.

    Runs mmdet detection, keeps boxes of category ``args.det_cat_id`` whose
    score exceeds ``args.bbox_thr``, applies NMS with ``args.nms_thr`` and
    then runs top-down pose estimation on the surviving boxes. When a
    visualizer is supplied, the prediction is drawn into it; the rendered
    image can be fetched afterwards with ``visualizer.get_image()``.

    Args:
        args: Parsed CLI namespace. Reads ``det_cat_id``, ``bbox_thr``,
            ``nms_thr`` and, when drawing, ``draw_heatmap``, ``draw_bbox``,
            ``show_kpt_idx``, ``skeleton_style``, ``show``, ``kpt_thr`` and
            (if present) ``show_interval``.
        img_path: Path of the image file to process.
        detector: Detection model returned by ``init_detector``.
        pose_estimator: Pose model returned by ``init_pose_estimator``.
        visualizer: Optional mmpose visualizer; ``None`` skips drawing.

    Returns:
        The ``pred_instances`` of the merged data sample, or ``None`` if the
        merged sample has no such field.

    Raises:
        ValueError: If the image file cannot be decoded.
    """
    # Decode once and reuse the array for detection, pose estimation and
    # drawing, instead of having each stage re-read the file from disk.
    img = mmcv.imread(img_path)
    if img is None:
        raise ValueError(f'Failed to decode image: {img_path}')

    # Predict person boxes.
    det_result = inference_detector(detector, img)
    pred_instance = det_result.pred_instances.cpu().numpy()
    # Append the score as a 5th column so nms() can rank the boxes.
    bboxes = np.concatenate(
        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
    keep = np.logical_and(pred_instance.labels == args.det_cat_id,
                          pred_instance.scores > args.bbox_thr)
    bboxes = bboxes[keep]
    # After NMS only the box coordinates are needed for pose estimation.
    bboxes = bboxes[nms(bboxes, args.nms_thr), :4]

    # Predict keypoints for each surviving box.
    pose_results = inference_topdown(pose_estimator, img, bboxes)
    data_samples = merge_data_samples(pose_results)

    if visualizer is not None:
        visualizer.add_datasample(
            'result',
            img,
            data_sample=data_samples,
            draw_gt=False,
            draw_heatmap=args.draw_heatmap,
            draw_bbox=args.draw_bbox,
            show_kpt_idx=args.show_kpt_idx,
            skeleton_style=args.skeleton_style,
            show=args.show,
            # Honour --show-interval (previously parsed but never used).
            wait_time=getattr(args, 'show_interval', 0),
            kpt_thr=args.kpt_thr)

    return data_samples.get('pred_instances', None)
def _collect_image_paths(img_folder, subfolder_names, hand_types,
                         extensions=('.png', '.jpg', '.jpeg')):
    """Return sorted image paths under ``img_folder/<subfolder>/<hand_type>``.

    Directories that do not exist are skipped rather than aborting the run.
    Extension matching is case-insensitive, so e.g. ``.JPG`` is included.
    """
    image_paths = []
    for subfolder_name in subfolder_names:
        for hand_type in hand_types:
            hand_folder = os.path.join(
                img_folder, str(subfolder_name), hand_type)
            if not os.path.isdir(hand_folder):
                continue
            # os.listdir order is arbitrary; sort for reproducible runs.
            for file_name in sorted(os.listdir(hand_folder)):
                if file_name.lower().endswith(extensions):
                    image_paths.append(os.path.join(hand_folder, file_name))
    return image_paths


def process_images(args, img_folder, detector, pose_estimator, visualizer=None,
                   subfolder_names=range(1, 100), hand_types=('aaa', 'bbb')):
    """Run detection + pose estimation over a two-level folder of images.

    Images are looked up in ``img_folder/<subfolder>/<hand_type>/``. For each
    image, the predictions are written as JSON under
    ``args.json_output_root`` (when ``args.save_predictions`` is true). The
    rendered image is written under ``args.img_output_root`` when that is
    non-empty and a visualizer is supplied. Both outputs mirror the image's
    path relative to ``img_folder``.

    Args:
        args: Parsed CLI namespace (also forwarded to ``process_one_image``).
        img_folder: Root folder containing the image tree.
        detector: Detection model returned by ``init_detector``.
        pose_estimator: Pose model returned by ``init_pose_estimator``.
        visualizer: Optional mmpose visualizer used for drawing.
        subfolder_names: First-level sub-directory names (default 1..99).
        hand_types: Second-level sub-directory names inside each subfolder.

    Returns:
        A list with one ``{"img_path", "instances"}`` dict per processed
        image when ``args.save_predictions`` is true, otherwise ``[]``.
    """
    image_paths = _collect_image_paths(img_folder, subfolder_names, hand_types)

    all_pred_instances = []
    for img_path in tqdm(image_paths, desc="Processing images"):
        pred_instances = process_one_image(
            args, img_path, detector, pose_estimator, visualizer)
        # Outputs mirror the input tree relative to img_folder.
        relative_path = os.path.relpath(img_path, img_folder)

        if args.save_predictions:
            single_result = {
                "img_path": img_path,
                "instances": split_instances(pred_instances),
            }
            all_pred_instances.append(single_result)
            json_filepath = os.path.join(
                args.json_output_root,
                os.path.splitext(relative_path)[0] + ".json")
            # Create missing output directories.
            os.makedirs(os.path.dirname(json_filepath), exist_ok=True)
            with open(json_filepath, 'w', encoding='utf-8') as f:
                json.dump(single_result, f, indent='\t')

        # The visualizer holds the image drawn by process_one_image above.
        if args.img_output_root and visualizer is not None:
            img_vis = visualizer.get_image()
            img_output_path = os.path.join(args.img_output_root, relative_path)
            os.makedirs(os.path.dirname(img_output_path), exist_ok=True)
            mmcv.imwrite(img_vis, img_output_path)

    return all_pred_instances
def _build_arg_parser():
    """Return the ArgumentParser for this batch pose-labelling script."""
    parser = ArgumentParser()
    parser.add_argument('det_config', help='Config file for detection')
    parser.add_argument('det_checkpoint', help='Checkpoint file for detection')
    parser.add_argument('pose_config', help='Config file for pose')
    parser.add_argument('pose_checkpoint', help='Checkpoint file for pose')
    parser.add_argument('--img-folder', type=str, default='images', help='Folder containing multiple images for processing')
    # NOTE(review): --input is accepted (kept from the upstream demo) but is
    # not read anywhere in this file.
    parser.add_argument('--input', type=str, default='', help='Image/Video file')
    parser.add_argument('--img-output-root', type=str, default='img_results', help='Directory to save the visualized images.')
    parser.add_argument('--json-output-root', type=str, default='json_results', help='Directory to save the JSON results.')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='whether to show img')
    # store_true with default=True can never be turned off by itself, so an
    # explicit opt-out flag writing the same dest is provided below.
    parser.add_argument(
        '--save-predictions',
        action='store_true',
        default=True,
        help='whether to save predicted results')
    parser.add_argument(
        '--no-save-predictions',
        dest='save_predictions',
        action='store_false',
        help='Do not write the per-image JSON results')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--det-cat-id',
        type=int,
        default=0,
        help='Category id for bounding box detection model')
    parser.add_argument(
        '--bbox-thr',
        type=float,
        default=0.3,
        help='Bounding box score threshold')
    parser.add_argument(
        '--nms-thr',
        type=float,
        default=0.3,
        help='IoU threshold for bounding box NMS')
    parser.add_argument(
        '--kpt-thr',
        type=float,
        default=0.3,
        help='Visualizing keypoint thresholds')
    parser.add_argument(
        '--draw-heatmap',
        action='store_true',
        default=False,
        help='Draw heatmap predicted by the model')
    parser.add_argument(
        '--show-kpt-idx',
        action='store_true',
        default=False,
        help='Whether to show the index of keypoints')
    parser.add_argument(
        '--skeleton-style',
        default='mmpose',
        type=str,
        choices=['mmpose', 'openpose'],
        help='Skeleton style selection')
    parser.add_argument(
        '--radius',
        type=int,
        default=3,
        help='Keypoint radius for visualization')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization')
    parser.add_argument(
        '--show-interval', type=int, default=0, help='Sleep seconds per frame')
    parser.add_argument(
        '--alpha', type=float, default=0.8, help='The transparency of bboxes')
    parser.add_argument(
        '--draw-bbox', action='store_true', help='Draw bboxes of instances')
    return parser


def main():
    """Batch-label images: mmdet person detection + mmpose whole-body pose.

    Parses the CLI, builds the detector, the pose estimator and (only when
    rendered images are saved or shown) the visualizer, then processes every
    image found under ``--img-folder`` via ``process_images``.

    Raises:
        ImportError: If mmdet could not be imported.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not has_mmdet:
        raise ImportError('Please install mmdet to run the demo.')
    if not args.img_folder:
        parser.error('Please specify the img-folder argument.')

    # build detector
    detector = init_detector(
        args.det_config, args.det_checkpoint, device=args.device)
    detector.cfg = adapt_mmdet_pipeline(detector.cfg)

    # build pose estimator
    pose_estimator = init_pose_estimator(
        args.pose_config,
        args.pose_checkpoint,
        device=args.device,
        cfg_options=dict(
            model=dict(test_cfg=dict(output_heatmaps=args.draw_heatmap))))

    # Drawing is only needed when rendered images are saved or displayed;
    # skipping the visualizer otherwise avoids a full render per image.
    visualizer = None
    if args.img_output_root or args.show:
        pose_estimator.cfg.visualizer.radius = args.radius
        pose_estimator.cfg.visualizer.alpha = args.alpha
        pose_estimator.cfg.visualizer.line_width = args.thickness
        visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
        visualizer.set_dataset_meta(
            pose_estimator.dataset_meta, skeleton_style=args.skeleton_style)

    process_images(args, args.img_folder, detector, pose_estimator, visualizer)
# Entry point: start batch labelling only when run as a script, not on import.
if __name__ == '__main__':
    main()
经测试运行正常。当然，自己使用时需要修改一些路径以及遍历文件夹的那部分代码，因为它是按照我自己的目录结构写的，其他地方无需修改。导出格式为 COCO-WholeBody 格式，包含 133 个关键点的坐标及置信度。输出图片是标注了关键点的数据集图片，也可以不输出（运行时传入 --img-output-root "" 即可）；输出它只是为了检查自动打标签的质量。也可以将输出的 json 文件映射回对应图片，检查关键点以及 bounding box 的效果，有时间我会更新一个检查脚本。