compute_rmse

import json
import numpy as np
from sklearn.metrics import mean_squared_error

# 加载预测数据和真实数据
with open('/mnt/data/eval_results.json', 'r') as f:
    pred_data = json.load(f)

with open('/mnt/data/person_keypoints_val2017.json', 'r') as f:
    gt_data = json.load(f)

# 将真实数据按 image_id 组织
gt_keypoints = {}
for annotation in gt_data['annotations']:
    image_id = annotation['image_id']
    if image_id not in gt_keypoints:
        gt_keypoints[image_id] = []
    gt_keypoints[image_id].append(annotation['keypoints'])

# 将预测数据按 image_id 组织
pred_keypoints = {}
for annotation in pred_data:
    image_id = annotation['image_id']
    if image_id not in pred_keypoints:
        pred_keypoints[image_id] = []
    pred_keypoints[image_id].append(annotation['keypoints'])

# 计算 RMSE
def calculate_rmse(gt_points, pred_points):
    # 仅保留可见的关键点 (visibility > 0)
    gt_points = np.array(gt_points).reshape(-1, 3)
    pred_points = np.array(pred_points).reshape(-1, 3)
    visible = gt_points[:, 2] > 0
    gt_points = gt_points[visible][:, :2]
    pred_points = pred_points[visible][:, :2]
    
    if len(gt_points) == 0:
        return np.nan
    
    return np.sqrt(mean_squared_error(gt_points, pred_points))

image_ids = set(gt_keypoints.keys()).intersection(set(pred_keypoints.keys()))
rmse_list = []

for image_id in image_ids:
    gt_points = sum(gt_keypoints[image_id], [])
    pred_points = sum(pred_keypoints[image_id], [])
    rmse = calculate_rmse(gt_points, pred_points)
    rmse_list.append(rmse)

# 计算所有图片的平均 RMSE
overall_rmse = np.nanmean(rmse_list)
print(f'Overall RMSE: {overall_rmse}')

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值