import json
import numpy as np
from sklearn.metrics import mean_squared_error
# Load the predicted keypoints and the ground-truth annotations.
# JSON is read as UTF-8 explicitly so the result does not depend on the
# platform's default text encoding.
with open('/mnt/data/eval_results.json', 'r', encoding='utf-8') as f:
    pred_data = json.load(f)
with open('/mnt/data/person_keypoints_val2017.json', 'r', encoding='utf-8') as f:
    gt_data = json.load(f)
# Group the ground-truth keypoints by image_id.
# Each image_id maps to a list holding one 'keypoints' entry per annotation,
# in the order the annotations appear in the file.
gt_keypoints = {}
for annotation in gt_data['annotations']:
    gt_keypoints.setdefault(annotation['image_id'], []).append(
        annotation['keypoints']
    )
# Group the predicted keypoints by image_id.
# Each image_id maps to a list holding one 'keypoints' entry per prediction,
# in the order the predictions appear in the file.
pred_keypoints = {}
for annotation in pred_data:
    pred_keypoints.setdefault(annotation['image_id'], []).append(
        annotation['keypoints']
    )
# Compute the RMSE between ground-truth and predicted keypoints.
def calculate_rmse(gt_points, pred_points):
    """Return the RMSE over the x/y coordinates of the usable keypoints.

    Both arguments are flat sequences that are reshaped into rows of three
    values ``(x, y, v)``. Rows are compared position by position, so the
    two sequences must describe the same keypoints in the same order.

    A ground-truth row is used only when its third value is greater than 0
    and neither of its coordinates is 0; the same rows are then taken from
    the prediction.

    Args:
        gt_points: Flat ground-truth values, length a multiple of 3.
        pred_points: Flat predicted values, same length as ``gt_points``.

    Returns:
        The square root of the mean squared coordinate error over all
        selected x and y values, or ``np.nan`` when no row is selected.

    Raises:
        ValueError: If the two inputs do not contain the same number of
            keypoints (or a length is not a multiple of 3).
    """
    # Reshape to (n, 3); float dtype keeps the squared differences exact
    # for integer inputs and matches what the mean is computed in.
    gt_arr = np.asarray(gt_points, dtype=float).reshape(-1, 3)
    pred_arr = np.asarray(pred_points, dtype=float).reshape(-1, 3)
    if gt_arr.shape != pred_arr.shape:
        # Without this check the mismatch shows up later as an opaque
        # boolean-index error on the prediction array.
        raise ValueError(
            f'keypoint count mismatch: {gt_arr.shape[0]} ground-truth '
            f'vs {pred_arr.shape[0]} predicted'
        )
    # Keep only keypoints with visibility > 0 and non-zero coordinates.
    usable = (gt_arr[:, 2] > 0) & (gt_arr[:, 0] != 0) & (gt_arr[:, 1] != 0)
    if not usable.any():
        return np.nan
    diff = gt_arr[usable, :2] - pred_arr[usable, :2]
    # Mean over every selected x and y value; this is the same quantity as
    # a uniformly averaged two-output mean squared error.
    return np.sqrt(np.mean(diff ** 2))
# Score only the images that appear in both the ground truth and the predictions.
image_ids = gt_keypoints.keys() & pred_keypoints.keys()
rmse_list = []
for image_id in image_ids:
    # Flatten the per-annotation keypoint lists of this image into one list.
    # A nested comprehension is linear, unlike sum(list_of_lists, []).
    # NOTE(review): calculate_rmse compares rows position by position, so this
    # assumes predictions and annotations of an image come in the same number
    # and order -- confirm against how eval_results.json was produced.
    gt_points = [value for keypoints in gt_keypoints[image_id] for value in keypoints]
    pred_points = [value for keypoints in pred_keypoints[image_id] for value in keypoints]
    rmse = calculate_rmse(gt_points, pred_points)
    rmse_list.append(rmse)
# Average the per-image RMSE values, ignoring NaN entries (images without any
# usable keypoint). With no common images the result is NaN, without the
# RuntimeWarning np.nanmean raises on an empty list.
overall_rmse = np.nanmean(rmse_list) if rmse_list else np.nan
print(f'Overall RMSE: {overall_rmse}')
# Snippet title: compute_rmse2
# (Web page footer: latest recommended article published 2024-07-29 21:16:24)