在 demo 文件中修改 run_demo 函数,修改后的 run_demo 如下:
def run_demo(net, image_provider, height_size, cpu, track, smooth,
             save_dir=r'D:\B\sk_data/'):
    """Run pose estimation on every frame from ``image_provider``.

    For each frame: run inference, group keypoints into per-person poses,
    draw the skeletons on the frame and on a blank canvas, show the blended
    overlay in a window, and save the skeleton-only image to ``save_dir``.

    Args:
        net: pose-estimation network; switched to eval mode here.
        image_provider: iterable yielding BGR frames (images or video).
        height_size: network input height, forwarded to infer_fast.
        cpu: run inference on CPU when True.
        smooth: smoothing flag forwarded to track_poses.
        track: propagate pose ids between frames via track_poses when True.
        save_dir: directory for the per-frame skeleton images
            (default preserves the original hard-coded path).

    Press Esc to quit; press 'p' to toggle pause.
    """
    net = net.eval()
    # NOTE(review): the original had a commented-out `net = net.cuda()`
    # for the non-CPU path — confirm where device placement happens now.
    stride = 8            # network output stride
    upsample_ratio = 4    # heatmap/PAF upsampling factor
    num_keypoints = Pose.num_kpts  # 18
    previous_poses = []
    delay = 1             # cv2.waitKey delay; 0 pauses playback
    frame_idx = 0
    for img in image_provider:
        frame_idx += 1
        orig_img = img.copy()
        # uint8 canvas for the skeleton-only image: the original used
        # np.zeros(img.shape) (float64), which cv2.imwrite mishandles.
        sk_img = np.zeros(img.shape, dtype=np.uint8)
        heatmaps, pafs, scale, pad = infer_fast(net, img, height_size, stride, upsample_ratio, cpu)

        # Collect candidate keypoints of every type; extract_keypoints
        # appends [x, y, conf, id] entries and returns how many it found.
        total_keypoints_num = 0
        all_keypoints_by_type = []
        for kpt_idx in range(num_keypoints):  # the 19th heatmap channel is background
            total_keypoints_num += extract_keypoints(heatmaps[:, :, kpt_idx], all_keypoints_by_type,
                                                     total_keypoints_num)

        # Group keypoints into per-person entries using the PAFs.
        pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type, pafs)
        # Map keypoint coordinates from network-output space back to the
        # original image: undo stride/upsampling, padding, then scaling.
        for kpt_id in range(all_keypoints.shape[0]):
            all_keypoints[kpt_id, 0] = (all_keypoints[kpt_id, 0] * stride / upsample_ratio - pad[1]) / scale
            all_keypoints[kpt_id, 1] = (all_keypoints[kpt_id, 1] * stride / upsample_ratio - pad[0]) / scale

        current_poses = []
        for n in range(len(pose_entries)):
            if len(pose_entries[n]) == 0:
                continue
            # Unfound keypoints stay at -1.
            pose_keypoints = np.ones((num_keypoints, 2), dtype=np.int32) * -1
            for kpt_id in range(num_keypoints):
                if pose_entries[n][kpt_id] != -1.0:  # keypoint was found
                    pose_keypoints[kpt_id, 0] = int(all_keypoints[int(pose_entries[n][kpt_id]), 0])
                    pose_keypoints[kpt_id, 1] = int(all_keypoints[int(pose_entries[n][kpt_id]), 1])
            # Index 18 of a pose entry holds the pose confidence score.
            pose = Pose(pose_keypoints, pose_entries[n][18])
            current_poses.append(pose)

        if track:
            track_poses(previous_poses, current_poses, smooth=smooth)
            previous_poses = current_poses

        # Draw skeletons on the frame and on the blank skeleton canvas.
        for pose in current_poses:
            pose.draw(img)
            pose.draw(sk_img)
        img = cv2.addWeighted(orig_img, 0.6, img, 0.4, 0)
        for pose in current_poses:
            cv2.rectangle(img, (pose.bbox[0], pose.bbox[1]),
                          (pose.bbox[0] + pose.bbox[2], pose.bbox[1] + pose.bbox[3]), (0, 255, 0))
            cv2.rectangle(sk_img, (pose.bbox[0], pose.bbox[1]),
                          (pose.bbox[0] + pose.bbox[2], pose.bbox[1] + pose.bbox[3]), (0, 255, 0))
            if track:
                cv2.putText(img, 'id: {}'.format(pose.id), (pose.bbox[0], pose.bbox[1] - 16),
                            cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 255))
                cv2.putText(sk_img, 'id: {}'.format(pose.id), (pose.bbox[0], pose.bbox[1] - 16),
                            cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 255))

        cv2.imshow('Lightweight Human Pose Estimation Python Demo', img)
        key = cv2.waitKey(delay)
        # Save the skeleton-only image for this frame.
        cv2.imwrite(os.path.join(save_dir, '{}.jpg'.format(frame_idx)), sk_img)
        if key == 27:  # Esc quits
            return
        elif key == ord('p'):  # 'p' toggles pause (waitKey(0) blocks)
            delay = 0 if delay == 1 else 1