LSP 数据集和 MPII 数据集是常见的人体姿态识别公开数据集,其标注(label)以 mat 格式保存(mat 是 MATLAB 的数据文件格式)。
为了方便在 Python 和 C++ 中读取这些 label,通常会把它们转存为 txt 或 json 格式。为了便于查看数据集中的关键点,作者还对关键点进行了可视化展示。
LSP数据集
import glob
import os
from scipy.io import loadmat
import numpy as np
from PIL import Image
import cv2
# Drawing palette, one colour tuple per keypoint index: five intensity steps
# (50..250) on each of the three channels in turn, then one dark grey entry.
_STEPS = range(50, 251, 50)
colors = (
    [(level, 0, 0) for level in _STEPS]
    + [(0, level, 0) for level in _STEPS]
    + [(0, 0, level) for level in _STEPS]
    + [(50, 50, 50)]
)
def save_joints(mat_path, image_path, save_path, show=True):
    """Write the LSP keypoints of every image to a text file and preview them.

    For each ``*.jpg`` in ``image_path`` the matching joint coordinates are
    taken from the ``joints`` array of the MAT file, saved as the ``str()`` of
    a ``{joint_index: [x, y]}`` dict in ``<save_path>/<image stem>.txt``, and
    drawn (dot plus index label) on the image.

    Args:
        mat_path: Path of the LSP MAT file, including the file name.
        image_path: Directory containing the LSP images (no file name).
        save_path: Directory that receives one ``.txt`` file per image.
        show: When true (the default, matching the original behaviour) each
            annotated image is displayed and the function waits for a key
            press before moving on. Pass ``False`` for a batch conversion.

    The LSP dataset contains 2000 images.
    """
    # The array is indexed [row, joint, image]; move the image axis to the
    # front and keep rows 0 and 1, which are used as x and y below.
    # NOTE(review): the dropped third row is presumably a visibility flag.
    joints = loadmat(mat_path)["joints"].transpose(2, 0, 1)[:, :2, :]

    # glob returns files in arbitrary order, but annotations are matched to
    # images purely by position, so the listing must be deterministic.
    # Assumes file names sort in annotation order (true for zero-padded
    # names such as im0001.jpg) - TODO confirm for other layouts.
    image_files = sorted(glob.glob(os.path.join(image_path, "*.jpg")))
    if len(image_files) != joints.shape[0]:
        print("warning: %d images but %d annotations; extra entries are ignored"
              % (len(image_files), joints.shape[0]))

    for img_path, cen_points in zip(image_files, joints):
        # basename/splitext work with both "/" and "\\" separators.
        stem = os.path.splitext(os.path.basename(img_path))[0]

        # Close the file handle promptly and force three channels so the
        # RGB -> BGR conversion cannot fail on grayscale or palette images.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert("RGB"), dtype=np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        point_dict = {}
        for idx in range(cen_points.shape[-1]):
            x = int(cen_points[0, idx])
            y = int(cen_points[1, idx])
            point_dict[str(idx)] = [x, y]
            # Wrap around the palette instead of raising when there are more
            # joints than colours.
            color = colors[idx % len(colors)]
            img = cv2.circle(img, (x, y), 5, color, thickness=-1)
            img = cv2.putText(img, str(idx), (x + 10, y),
                              cv2.FONT_HERSHEY_SIMPLEX, 1, color, 1)

        with open(os.path.join(save_path, stem + ".txt"), "w") as img_txt:
            img_txt.write(str(point_dict))

        # Visual check that the keypoints are in the right place.
        if show:
            cv2.imshow("img", img)
            cv2.waitKey()
MPII数据集
import os
from scipy.io import loadmat
import numpy as np
from PIL import Image
import cv2
# Palette for the MPII joints, indexed by the joint's position in the record.
_MPII_JOINT_COLORS = [
    (50, 0, 0), (100, 0, 0), (150, 0, 0), (200, 0, 0), (250, 0, 0),
    (0, 50, 0), (0, 100, 0), (0, 150, 0), (0, 200, 0), (0, 250, 0),
    (0, 0, 50), (0, 0, 100), (0, 0, 150), (0, 0, 200), (0, 0, 250),
    (50, 50, 50),
]


def _visibility_flag(raw):
    """Turn one raw ``is_visible`` entry into an int; an empty entry gives 0."""
    values = np.asarray(raw).ravel()
    if values.size == 0:
        return 0
    # The first element may be a number or a one-character string.
    return int(values[0])


def _parse_person(annopoint, head_x1, head_y1, head_x2, head_y2):
    """Extract head rectangle, joint positions and visibility for one person."""
    head_rect = [float(head_x1[0, 0]), float(head_y1[0, 0]),
                 float(head_x2[0, 0]), float(head_y2[0, 0])]

    points = annopoint['point'][0, 0]
    joint_ids = [str(j_id[0, 0]) for j_id in points['id'][0]]
    xs = [float(x[0, 0]) for x in points['x'][0]]
    ys = [float(y[0, 0]) for y in points['y'][0]]
    joint_pos = {j_id: [x, y] for j_id, x, y in zip(joint_ids, xs, ys)}

    if 'is_visible' in (points.dtype.names or ()):
        vis = {j_id: _visibility_flag(raw)
               for j_id, raw in zip(joint_ids, points['is_visible'][0])}
    else:
        vis = None
    return head_rect, joint_pos, vis


def _draw_person(canvas, head_rect, joint_pos):
    """Draw the head rectangle and the numbered joints of one person."""
    canvas = cv2.rectangle(canvas,
                           (int(head_rect[0]), int(head_rect[1])),
                           (int(head_rect[2]), int(head_rect[3])),
                           color=(255, 0, 0), thickness=4)
    # The label is the joint's position in the record, not its MPII id.
    for idx, (joint_x, joint_y) in enumerate(joint_pos.values()):
        color = _MPII_JOINT_COLORS[idx % len(_MPII_JOINT_COLORS)]
        centre = (int(joint_x), int(joint_y))
        canvas = cv2.circle(canvas, centre, 10, color, thickness=-1)
        canvas = cv2.putText(canvas, str(idx), (centre[0] + 10, centre[1]),
                             cv2.FONT_HERSHEY_SIMPLEX, 1, color, 1)
    return canvas


def save_joints(mat_path, image_path, save_path, show=True):
    """Convert the MPII MAT annotations to one text file per image.

    For every image listed in the MAT file that exists in ``image_path`` and
    has an ``annopoints`` field, a file ``<save_path>/<image stem>.txt`` is
    written with one line per annotated person. Each line is the ``str()`` of
    a dict with the keys ``train``, ``head_rect``, ``is_visible`` and
    ``joint_pos``. The head rectangle and joints are also drawn on the image.

    Args:
        mat_path: Path of the MPII annotation MAT file.
        image_path: Directory containing the MPII images.
        save_path: Directory that receives the text files.
        show: When true (the default, matching the original behaviour) each
            annotated image is displayed until a key is pressed. Pass
            ``False`` for a batch conversion.
    """
    release = loadmat(mat_path)['RELEASE']
    annolist = release['annolist'][0, 0][0]
    train_flags = release['img_train'][0, 0][0]

    for anno, train_flag in zip(annolist, train_flags):
        img_fn = anno['image']['name'][0, 0][0]
        img_path = os.path.join(image_path, img_fn)
        if not os.path.exists(img_path):
            print("error, not exist", img_path)
            continue

        annorect = anno['annorect']
        if 'annopoints' not in (annorect.dtype.names or ()):
            # No joint annotations for this image: nothing to save or draw.
            continue

        # Close the file handle promptly and force three channels so the
        # RGB -> BGR conversion cannot fail on grayscale or palette images.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert("RGB"), dtype=np.uint8)
        img1 = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        train_flag = int(train_flag)

        records = []
        persons = zip(annorect['annopoints'][0],
                      annorect['x1'][0], annorect['y1'][0],
                      annorect['x2'][0], annorect['y2'][0])
        for annopoint, head_x1, head_y1, head_x2, head_y2 in persons:
            # Persons without joints come through as empty arrays. Check the
            # size explicitly: comparing an ndarray with [] is unreliable
            # across NumPy versions.
            if np.size(annopoint) == 0:
                continue
            head_rect, joint_pos, vis = _parse_person(
                annopoint, head_x1, head_y1, head_x2, head_y2)
            img1 = _draw_person(img1, head_rect, joint_pos)
            records.append(str({
                'train': train_flag,
                'head_rect': head_rect,
                'is_visible': vis,
                'joint_pos': joint_pos,
            }))

        # NOTE(review): the original indentation was lost. The write and the
        # preview are assumed to run once per image that has an
        # ``annopoints`` field - confirm against the intended behaviour.
        txt_name = os.path.splitext(img_fn)[0] + ".txt"
        with open(os.path.join(save_path, txt_name), 'w') as fp:
            fp.write("".join(record + "\n" for record in records))
        # Report the name actually written (the original message printed
        # "<name>.jpg.txt").
        print(f" {txt_name} 保存成功")

        # Visual check that the keypoints are in the right place.
        if show:
            cv2.imshow("img_video", img1)
            cv2.waitKey()
            cv2.destroyAllWindows()
注:MPII 数据集的 label 转换代码参考了下面的项目。
参考 :GitHub - Fangyh09/PoseDatasets: Filter multiple pose datasets (coco, flic, lsp, mpii, ai_challenge)