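# Visualize keypoint heatmaps for a single image: save 17 images, one per
# COCO keypoint type, where each heatmap shows the location of that same
# keypoint for every person detected in the image.
## Generate heatmaps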
import os
import numpy as np
import json
import cv2
from itertools import groupby
import random
from matplotlib import pyplot as plt
# Input/output locations (Windows paths written with forward slashes).
#   dataset_dir:  folder of COCO val2017 images, read as <12-digit id>.jpg
#   dataset_save: result folder; NOTE(review): not referenced by the visible
#                 code, which writes its images under "D:/send_paper/".
dataset_dir, dataset_save = (
    "D:/send_paper/COCO val2017/val2017/",
    "D:/send_paper/COCO val2017/result/",
)
def normalization(data):
    """Linearly rescale *data* so that its values span [0, 255].

    The minimum maps to 0 and the maximum to 255. The result is a float
    array; callers cast to uint8 themselves if needed.

    A constant input has no range to stretch. Dividing by that zero range
    would produce an all-NaN array, so it is mapped to all zeros instead.
    An empty input is returned as an empty float array rather than raising
    from ``np.max``.
    """
    data = np.asarray(data)
    if data.size == 0:
        return data.astype(float)
    data_min = np.min(data)
    _range = np.max(data) - data_min
    if _range == 0:
        # Guard against 0/0 -> NaN for flat inputs.
        return np.zeros(data.shape, dtype=float)
    return (data - data_min) / _range * 255
def CenterLabelHeatMap(img, kpts, sigma):
    """Draw a Gaussian blob at every keypoint and blend the result onto *img*.

    Args:
        img: image array whose shape unpacks to (H, W, C); presumably a BGR
            uint8 image from cv2.imread (TODO confirm for other callers).
        kpts: sequence of keypoints. Only kpt[0] (x) and kpt[1] (y) are read;
            any further entries (e.g. a score) are ignored. May be empty.
        sigma: standard deviation of each Gaussian, in pixels.

    Returns:
        (blended, heatmap): ``heatmap`` is the JET color-mapped uint8 rendering
        of all blobs. ``blended`` is ``img * 0.3 + heatmap * 0.7`` (a float
        array).
    """
    img_height, img_width, _ = img.shape
    # The grid gives pixel column j the coordinate j + 1 (1-based), matching
    # the previous implementation. NOTE(review): keypoint coordinates are
    # probably 0-based, so blobs may sit one pixel off; confirm if it matters.
    # A row vector and a column vector broadcast to (H, W). This builds the
    # grid once instead of calling meshgrid for every keypoint.
    xs = np.linspace(1, img_width, img_width)[np.newaxis, :]
    ys = np.linspace(1, img_height, img_height)[:, np.newaxis]
    two_sigma_sq = 2.0 * sigma * sigma

    # Accumulate in a wide integer type. Summing the per-keypoint uint8 maps
    # in uint8 wraps around wherever blobs overlap (e.g. 200 + 200 -> 144).
    # Starting from a zero canvas also makes an empty kpts list valid.
    result = np.zeros((img_height, img_width), dtype=np.int32)
    for kpt in kpts:
        d2 = (xs - kpt[0]) ** 2 + (ys - kpt[1]) ** 2
        blob = normalization(np.exp(-d2 / two_sigma_sq))
        # Truncating cast to uint8, as before, so non-overlapping output
        # is unchanged.
        result += np.array(blob, np.uint8)
    result = np.clip(result, 0, 255).astype(np.uint8)

    heatmap = cv2.applyColorMap(result, cv2.COLORMAP_JET)
    img = img * 0.3 + heatmap * 0.7
    return img, heatmap
def CenterGaussianHeatMap(img, c_x, c_y, variance):
    """Return *img* recolored with OpenCV's JET colormap.

    A Gaussian map centred on pixel (c_x, c_y) is also computed and rescaled
    to [0, 255] in channel 0 of an (H, W, 3) array. It is NOT blended into
    the returned image, because the blending step is disabled. Despite its
    name, *variance* acts as a standard deviation here: the exponent is
    dist^2 / (2 * variance^2).

    Returns:
        The JET color-mapped image.
    """
    img_height, img_width, _ = img.shape
    img = cv2.applyColorMap(img, cv2.COLORMAP_JET)  # same as the literal 2
    # Squared distance to (c_x, c_y) for every pixel, built by broadcasting.
    # This replaces a per-pixel double loop of H*W interpreted iterations
    # that produced the same values.
    cols = np.arange(img_width)[np.newaxis, :]
    rows = np.arange(img_height)[:, np.newaxis]
    dist_sq = (cols - c_x) ** 2 + (rows - c_y) ** 2
    gaussian_map = np.zeros((img_height, img_width, 3))
    gaussian_map[:, :, 0] = np.exp(-(dist_sq / 2.0 / variance / variance))
    gaussian_map = normalization(gaussian_map)
    # NOTE(review): gaussian_map is never used. The blend
    #   img[:,:,0] = img[:,:,0]*0.1 + gaussian_map[:,:,0]*0.9
    # is disabled. Re-enable it or drop this computation.
    return img
"""
先用groupby将同一个key下的不同人体相同关节点位置取出来
"""
with open("D:/send_paper/keypoints_val2017_results_0.json","r") as load_f:
load_dict = json.load(load_f)
ret = list()
ret_num = list()
for group_num,group in groupby(load_dict,lambda x: x.get("image_id")):
ret.append(list(group))
ret_num.append(group_num)
ret_result = zip(ret_num,ret)
imgIds_old = 0
image = cv2.imread(os.path.join(dataset_dir, str(397133).zfill(12) + '.jpg'))
for dict_num in ret_result:
imgIds = dict_num[0]
if imgIds==127263 :
image_path = os.path.join(dataset_dir, str(imgIds).zfill(12) + '.jpg')
image = cv2.imread(image_path)
for kpts_num in range(17):
joints = list()
for person_num in range(len(dict_num[1])):
person_mess = np.array(dict_num[1][person_num]['keypoints']).reshape(-1,3)
joints.append(person_mess[kpts_num])
img,heatmap=CenterLabelHeatMap(image,joints,4)
cv2.imwrite("D:/send_paper/{}.jpg".format(kpts_num),img)