def mean_ap(
        distmat,
        query_ids=None,
        gallery_ids=None,
        query_cams=None,
        gallery_cams=None,
        average=True):
    """Compute (mean) average precision from a query-gallery distance matrix.

    For every query, gallery entries are ranked by ascending distance,
    entries sharing BOTH the query's identity and the query's camera are
    dropped, and the average precision (AP) of the remaining ranking is
    computed with ``average_precision_score``.

    Args:
        distmat: 2-D array-like of shape ``(m, n)``; ``distmat[i, j]`` is the
            distance between query ``i`` and gallery item ``j`` (smaller
            distances rank first).
        query_ids: Length-``m`` identity labels of the queries. Defaults to
            ``np.arange(m)``.
        gallery_ids: Length-``n`` identity labels of the gallery. Defaults to
            ``np.arange(n)``.
        query_cams: Length-``m`` camera labels of the queries. Defaults to
            all zeros.
        gallery_cams: Length-``n`` camera labels of the gallery. Defaults to
            all ones, so with default cameras no gallery item is filtered.
        average: If True, return the mean AP over valid queries; otherwise
            return the per-query arrays.

    Returns:
        If ``average`` is True, the sum of the APs divided by the number of
        valid queries. Otherwise a tuple ``(aps, is_valid_query)`` of two
        length-``m`` arrays, where ``is_valid_query[i]`` is 1 when query ``i``
        has at least one true match left after filtering (else 0) and
        ``aps[i]`` is its AP (0 for invalid queries).

    Raises:
        RuntimeError: If there are no queries at all, or if ``average`` is
            True and no query is valid (the mean would be 0/0).
    """
    distmat = np.asarray(distmat)
    m, n = distmat.shape

    # Fill in defaults so that omitted arguments do not crash on indexing.
    if query_ids is None:
        query_ids = np.arange(m)
    if gallery_ids is None:
        gallery_ids = np.arange(n)
    if query_cams is None:
        query_cams = np.zeros(m, dtype=np.int32)
    if gallery_cams is None:
        gallery_cams = np.ones(n, dtype=np.int32)

    # Accept plain sequences as well as arrays (fancy indexing needs arrays).
    query_ids = np.asarray(query_ids)
    gallery_ids = np.asarray(gallery_ids)
    query_cams = np.asarray(query_cams)
    gallery_cams = np.asarray(gallery_cams)

    # Sort and find correct matches
    indices = np.argsort(distmat, axis=1)
    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])

    # Compute AP for each query
    aps = np.zeros(m)
    is_valid_query = np.zeros(m)
    for i in range(m):
        # Filter out the same id and same camera, i.e. keep everything except
        # (gallery_id == query_id) & (gallery_cam == query_cam).
        valid = ((gallery_ids[indices[i]] != query_ids[i]) |
                 (gallery_cams[indices[i]] != query_cams[i]))
        y_true = matches[i, valid]
        if not np.any(y_true):
            # No true match survives the filter: AP is undefined, skip.
            continue
        # Negate distances so that a larger score means a better rank.
        y_score = -distmat[i][indices[i]][valid]
        is_valid_query[i] = 1
        # y_true holds the ground-truth labels, y_score the retrieval scores.
        aps[i] = average_precision_score(y_true, y_score)

    if len(aps) == 0:
        raise RuntimeError("No valid query")
    if average:
        num_valid = np.sum(is_valid_query)
        # aps is preallocated, so its length says nothing about validity;
        # guard the division explicitly instead of returning NaN.
        if num_valid == 0:
            raise RuntimeError("No valid query")
        return float(np.sum(aps)) / num_valid
    return aps, is_valid_query
average_precision_score 的计算方式参见 scikit-learn 官方文档:https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html
行人重识别mAP的源代码
最新推荐文章于 2024-07-16 13:34:36 发布