问题
最近遇到一个问题,需要对两组数据进行匹配,使得匹配后两组数据的欧氏距离最小。将两组数据的欧氏距离的倒数转成矩阵形式,于是问题转化为了求矩阵不同行不同列和的最大值。
使用KM算法可以比较简单的实现这个需求。
算法原理
KM算法解决不同行不同列求和最大值问题
KM算法详解+模板(男生女生配)
代码实现
按照上述的讲解,我写了一份自己的KM算法实例,尽可能加好了注释
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <vector>
#define N 5
#define INF 9999999999
int w[N][N] = {
//Y
//X
{1,3,4,3,5},
{4,8,9,6,3},
{4,3,2,8,9},
{1,2,4,8,1},
{8,9,3,1,8},
};//记下权值W[xindex][yindex]
int cx[N] = {0}; //记下X集合的标杆值
int cy[N] = {0}; //记下Y集合的标杆值
bool visX[N]={false}; //本轮是否访问过x
bool visY[N]={false}; //本轮是否访问过y
int match[N] = {-1}; //记下所有y的匹配情况,默认为-1
/*
true 表示x匹配到了y
false 表示x没法匹配到y
*/
bool dfs(int xidx){
visX[xidx] = true; //设置当前x已经访问过
for (int yidx = 0 ; yidx < N ; yidx++){
//如果当前的y没有访问过且cx+cy-w=0说明当前x和当前y正好是匹配的情况
if (!visY[yidx] && (cx[xidx] + cy[yidx] - w[xidx][yidx])==0){
visY[yidx] = true;
//如果y目前还没有匹配或者y已经匹配了,但是y匹配的那个x可以找到另一个更好的
if (match[yidx]==-1 || dfs(match[yidx])){
match[yidx] = xidx;
return true;
}
}
}
return false;
}
int KM(){
memset(match, -1, sizeof(match)); // 初始y的匹配情况,均为-1,表示未匹配
memset(cy, 0, sizeof(cy)); // 初始化y的标杆值为0
//使用贪心算法初始化x的标杆值,为最大的权值
for(int i = 0; i <= N; i++)
for (int j = 0; j <= N; j++)
cx[i] = std::max(cx[i], w[i][j]);
for(int i = 0; i < N ; i++){ //遍历x
//对每一个x,循环扫描
while(1){
//每一轮重置访问变量xy
memset(visX, 0, sizeof(visX));
memset(visY, 0, sizeof(visY));
if (dfs(i)) break; //如果已经找到匹配的y,则退出
//如果找不到可以匹配的y,开始寻找变动最小的调整方案,求出调整后最小的权值变化情况d
int d = INF;
for (int xidx = 0 ; xidx < N;xidx++){
if (visX[xidx]){
for (int yidx = 0 ; yidx < N ; yidx ++){
if(!visY[yidx]){
d = std::min(d, cx[xidx] + cy[yidx] - w[xidx][yidx]);
}
}
}
}
if (d==INF)
return -1;
//根据最小权值变化量d,对所有当前轮次访问过的x,其标杆值-=d,对访问过的y,其标杆值+=d
for (int j = 0; j <= N; j++){
if (visX[j]) cx[j] -= d;
if (visY[j]) cy[j] += d;
}
std::cout<<"d:"<<d<<std::endl;
for(int j = 0 ; j < N ;j++){
printf("y[%d]-->x[%d]\n",j,match[j]);
}
for(int j = 0 ; j < N ;j++){
printf("cx[%d]%d\n",j,cx[j]);
}
for(int j = 0 ; j < N ;j++){
printf("cy[%d]%d\n",j,cy[j]);
}
}
}
//打印匹配情况
int val = 0;
//遍历j
for(int i = 0 ; i < N ;i++){
printf("y[%d]-->x[%d]\n",i,match[i]);
val += w[match[i]][i];
}
std::cout<<val<<std::endl;
}
int main(){
KM();
return 0;
}