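# -*- coding: utf-8 -*-
'''
用匈牙利算法实现标签的匹配问题,并且输出最后的标签
(Use the Hungarian algorithm to match predicted labels to ground-truth labels,
then output the relabelled predictions.)
'''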
import numpy as np
from scipy.optimize import linear_sum_assignment
def match_labels(label, gnd):
    """Relabel ``label`` so that its clusters line up with the classes of ``gnd``.

    Builds a cost matrix whose entry ``[i, j]`` is the number of samples
    predicted as the i-th distinct value of ``label`` whose ground truth is
    *not* the j-th distinct value of ``gnd``. It then solves the minimum-cost
    assignment with the Hungarian algorithm (``linear_sum_assignment``).

    Args:
        label: array-like of predicted labels (e.g. cluster ids). Its values
            need not coincide with the values used in ``gnd``.
        gnd: array-like of ground-truth labels, same shape as ``label``.

    Returns:
        ``(result_label, cost_mat, row_index, col_index)``:

        * ``result_label`` is a float array shaped like ``label``. Each sample
          holds the ground-truth label its predicted cluster was matched to.
          The value is 0 for clusters left unmatched, which can only happen
          when ``label`` has more distinct values than ``gnd``.
        * ``cost_mat`` has rows indexed by ``np.unique(label)`` and columns
          indexed by ``np.unique(gnd)``.
        * ``row_index`` and ``col_index`` are the optimal assignment returned
          by ``linear_sum_assignment``.

    Raises:
        ValueError: if ``label`` and ``gnd`` have different shapes.
    """
    label = np.asarray(label)
    gnd = np.asarray(gnd)
    if label.shape != gnd.shape:
        raise ValueError(
            f"label and gnd must have the same shape, got {label.shape} and {gnd.shape}"
        )

    # Rows come from the predicted alphabet and columns from the ground-truth
    # alphabet. The two may differ, both in their values and in their size.
    pred_classes = np.unique(label)
    true_classes = np.unique(gnd)

    cost_mat = np.zeros((len(pred_classes), len(true_classes)))
    for i, pred in enumerate(pred_classes):
        members = gnd[label == pred]
        # For every candidate class, count the cluster members that would be
        # mislabelled if this cluster were mapped to that class.
        cost_mat[i] = (members[:, None] != true_classes[None, :]).sum(axis=0)

    row_index, col_index = linear_sum_assignment(cost_mat)

    result_label = np.zeros(label.shape)
    for r, c in zip(row_index, col_index):
        result_label[label == pred_classes[r]] = true_classes[c]
    return result_label, cost_mat, row_index, col_index


# Demo data: predicted cluster labels and the corresponding ground truth.
label = np.array([1, 1, 2, 1, 1, 2, 2, 2, 3, 2, 2, 3, 1, 3, 3, 2, 3])
gnd = np.array([2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 1, 1, 1, 1, 1])
K = np.unique(gnd)

result_label, cost_mat, row_index, col_index = match_labels(label, gnd)
# Ground-truth label assigned to each matched predicted cluster. The original
# `col_index + 1` was only correct when the classes were exactly 1..K.
assignment = K[col_index]

print(cost_mat)
print(assignment)
print('*' * 200)
print(result_label)
print(cost_mat[row_index, col_index])
print(cost_mat[row_index, col_index].sum())
# 参考链接 (reference link): the URL itself is missing from the original source.