import numpy as np
# Synthetic demo data for the AP computation below:
#   * a "database" of 10000 random feature vectors, each of length 512,
#   * one integer label per database row, drawn from 1..9 inclusive,
#   * a single random query vector (shape (1, 512)) whose label is 1.
# NOTE: ``data_lable`` keeps its original (misspelled) name because the
# ``__main__`` block passes it to ``AP`` under that name.
data_feature = np.random.random_sample((10000, 512))
data_lable = np.random.randint(1, 10, size=10000)
query = np.random.random_sample((1, 512))
query_label = 1
def AP(data_feature, data_label, query, query_label, k, *, verbose=True):
    """Compute the average precision (AP@k) of one query against a database.

    Database rows are ranked by their dot-product score with ``query``
    (highest score first). Within the top ``k`` ranked rows, every position
    whose label equals ``query_label`` is a hit; AP is the mean of
    precision@(i+1) taken at each hit position i.

    :param data_feature: feature matrix of the image database, one row per image.
    :param data_label: labels of the image database, aligned with the rows of
        ``data_feature``; must support integer-array indexing (e.g. a NumPy array).
    :param query: query feature(s); must be compatible with
        ``np.dot(data_feature, query.T)``.
    :param query_label: label of the query.
    :param k: number of top-ranked images to evaluate (assumed positive).
    :param verbose: if True (the default), print intermediate ranking details.
    :return: a length-1 NumPy array holding AP@k, or the string ``"not hit"``
        when none of the top ``k`` images has ``query_label``.
    """
    scores = np.dot(data_feature, query.T)
    # argsort is ascending; reverse it so the best match comes first.
    rank = scores.flatten().argsort()[::-1]
    rank_label = data_label[rank]

    # Compare once and reuse: the hit positions and the running hit count must
    # cover the same top-k window, otherwise a hit index can fall outside `hit`.
    is_hit = query_label == rank_label[:k]
    hit_index = np.argwhere(is_hit)  # shape (num_hits, 1)
    hit = np.cumsum(is_hit)          # hit[i] = number of hits among the first i+1 items

    if verbose:
        print("query_label:", query_label)
        print("rank_label:", rank_label[:k])
        print("hit_index:", hit_index.flatten())
        print("hit:", hit)

    if len(hit_index) == 0:
        return "not hit"

    # precision@(i+1) evaluated at every hit position i; shape (num_hits, 1)
    precisions = hit[hit_index] / (hit_index + 1)
    ap = np.mean(precisions, axis=0)  # averaging over axis 0 keeps shape (1,)
    if verbose:
        print(precisions)
        print("AP", ap)
    return ap
if __name__ == '__main__':
    # Bind the result to a new name rather than rebinding ``AP``: reusing the
    # name would shadow the function, so any later ``AP(...)`` call would fail.
    ap = AP(data_feature, data_lable, query, query_label, k=1)
    print("AP:", ap)
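# Sample output. The data is random, so values differ between runs. This output
# appears to come from a run that evaluated the top 20 results, not the k=1 /
# top-10 printing used above.
# query_label: 1
# rank_label: [2 5 5 4 2 7 4 8 2 7 3 1 9 6 1 6 7 1 6 8]
# hit_index: [11 14 17]
# hit: [0 0 0 0 0 0 0 0 0 0 0 1 1 1 2 2 2 3 3 3]
# [[0.08333333]
#  [0.13333333]
#  [0.16666667]]
# AP [0.12777778]
# AP: [0.12777778]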