1、介绍
人体关键点检测(Human Keypoints Detection)又称为人体姿态估计2D Pose,是计算机视觉中一个相对基础的任务,是人体动作识别、行为分析、人机交互等的前置任务。一般情况下可以将人体关键点检测细分为单人/多人关键点检测、2D/3D关键点检测,同时有算法在完成关键点检测之后还会进行关键点的跟踪,也被称为人体姿态跟踪。
2、数据处理
人体关键点数据集有很多,这里训练采用的是COCO2017人体关键点检测数据集。
首先,对数据集进行预处理,对person_keypoints_train2017.json
文件进行分析,解析出每张人体图片的关键点信息以及标注框信息。
from scipy import io
import json
from collections import defaultdict
# Path to the COCO 2017 person-keypoints training annotation file (JSON).
json_file = r'C:\Users\Ning Hui\Desktop\annotations\person_keypoints_train2017.json'
import json
from collections import defaultdict
import numpy as np
import os
from tqdm import tqdm
import shutil
from matplotlib import pyplot as plt
import skimage.io as io
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon,Rectangle
def show(ax, keypoints, bbox):
    """Draw one person's COCO keypoints, skeleton limbs and bounding box.

    ``keypoints`` is the flat COCO triple list [x1, y1, v1, x2, y2, v2, ...]
    (v: 0 = not labelled, 1 = labelled but occluded, 2 = visible);
    ``bbox`` is [x, y, w, h]. Lines/points go to the current pyplot axes,
    patches are added to *ax*.
    """
    # One random bright colour shared by this person's limbs, points and box.
    color = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
    # COCO person skeleton converted to 0-based keypoint indices.
    skeleton = np.array([
        [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12],
        [5, 6], [5, 7], [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2],
        [1, 3], [2, 4], [3, 5], [4, 6],
    ])
    triples = np.array(keypoints)
    xs, ys, vis = triples[0::3], triples[1::3], triples[2::3]
    for limb in skeleton:
        # Draw a limb only when both of its endpoints are labelled.
        if np.all(vis[limb] > 0):
            plt.plot(xs[limb], ys[limb], linewidth=1, color=color)
    # Labelled points: black edge; visible points redrawn with coloured edge.
    plt.plot(xs[vis > 0], ys[vis > 0], 'o', markersize=4,
             markerfacecolor=color, markeredgecolor='k', markeredgewidth=1)
    plt.plot(xs[vis > 1], ys[vis > 1], 'o', markersize=4,
             markerfacecolor=color, markeredgecolor=color, markeredgewidth=1)
    x, y, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
    # Semi-transparent fill plus a solid coloured outline for the box.
    ax.add_patch(Polygon(xy=[[x, y], [x, y + h], [x + w, y + h], [x + w, y]],
                         color='k', alpha=0.3))
    ax.add_patch(Rectangle(xy=(x, y), width=w, height=h,
                           fill=False, color=color, alpha=1))
# Parse the annotation file. Keep only single-person (non-crowd) annotations
# of category "person" with more than 8 labelled keypoints, grouped per image.
with open(json_file, 'r', encoding='utf-8') as f:
    data = json.load(f)

d = defaultdict(list)
for ann in tqdm(data['annotations']):
    if ann['iscrowd']:
        continue
    if ann['num_keypoints'] <= 8 or ann['category_id'] != 1:
        continue
    d[ann['image_id']].append({
        'category_id': 0,  # remapped: person is the only class in our dataset
        'key_points': ann['keypoints'],
        'bbox': ann['bbox'],
    })
# Attach each image's metadata (file name and size) as the LAST element of
# its annotation list; images without any kept annotation are skipped.
for img_info in data['images']:
    if img_info['id'] not in d:
        continue
    d[img_info['id']].append({
        'img_name': img_info['file_name'],
        'width': img_info['width'],
        'height': img_info['height'],
    })
# Number of images that have at least one usable person annotation.
print(len(d))
# Export one JSON label file per image and copy the image next to it,
# drawing each person's keypoints/box as a visual sanity check.
img_prefix = 'G:/train2017'
save_path = 'F:/keypoint_dataset'
# FIX: the original listed F:\keypoint_dataset\images TWICE inside the loop
# (identical checks at two spots), re-reading the directory for every image —
# O(n^2) disk reads. List it once into a set for O(1) membership tests.
# Safe because each image id maps to a unique file name, so files copied
# during this run are never re-queried.
existing_images = set(os.listdir(r'F:\keypoint_dataset\images'))

for img_id in tqdm(d.keys()):
    meta = d[img_id][-1]  # image metadata was appended last
    if meta['img_name'] in existing_images:
        continue  # already exported by a previous (interrupted) run
    dt = {
        'image_name': meta['img_name'],
        'category': 0,
        'width': meta['width'],
        'height': meta['height'],
        'key_points': [],
        'bbox': [],
    }
    img_path = '%s/%s' % (img_prefix, meta['img_name'])
    json_path = '%s/labels/%s' % (save_path, meta['img_name'].replace('.jpg', '.json'))
    # Show the image with all annotations overlaid.
    I = io.imread(img_path)
    plt.imshow(I)
    plt.axis('off')
    ax = plt.gca()
    # All entries except the trailing metadata record are person annotations.
    for ann in d[img_id][:-1]:
        dt['key_points'].append(ann['key_points'])
        dt['bbox'].append(ann['bbox'])
        show(ax, ann['key_points'], ann['bbox'])
    with open(json_path, 'w') as f:
        json.dump(dt, f)
    shutil.copy(img_path, 'F:/keypoint_dataset/images/')
完成后,我们得到每张图片文件以及对应的标注文件,标注格式为json,然后我们要转化为yolo格式,具体代码内容可以参考我的博客【数据集】Yolo人体关键点数据集处理
3、yolov8安装与权重下载
yolov8安装可以直接使用pip命令安装,如果速度较慢可以进行换源:
pip install ultralytics -i https://pypi.douban.com/simple/
权重下载
权重下载可以到github官网下载对应的权重,这里我下载的是YOLOv8s-pose(与后面训练代码中加载的yolov8s-pose.pt对应)。
4、数据集配置与模型配置
4.1、数据集配置
path: /home/ningh/yolo/poseData # dataset root dir
train: images/train # train images (relative to 'path')
val: images/val # val images (relative to 'path')
test: images/test # test images (optional)
# Keypoints
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
# Classes
names:
0: person
4.2、模型配置
模型配置采用的是yolov8s-pose.yaml
,可以在github中找到,
ultralytics/cfg/models/v8/yolov8s-pose.yaml
5、模型训练
from ultralytics import YOLO
# Load a model
model = YOLO("/home/ningh/yolo/ultralytics/ultralytics/cfg/models/v8/yolov8s-pose.yaml") # build a new model from YAML
model = YOLO("./weights/yolov8s-pose.pt") # load a pretrained model (recommended for training)
# model = YOLO("/home/ningh/yolo/ultralytics/ultralytics/cfg/models/v8/yolov8-pose.yaml").load("./weights/yolov8s-pose.pt") # build from YAML and transfer weights
# Train the model
results = model.train(data="./project/poseDetect/data.yaml", epochs=200, imgsz=640)
# results = model("https://ultralytics.com/images/bus.jpg")