这是一道比较典型的KM算法题 ,先练一下O( n ^ 4)的算法,程序如下:
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<cstdio>
#define mem(a , b) memset(a , b , sizeof(a))
using namespace std ;
const int MAXN = 305 ;
const int INF = 0x7fffffff ;
int link[MAXN] , Lx[MAXN] , Ly[MAXN] ;
int visx[MAXN] ;
int visy[MAXN] ;
int W[MAXN][MAXN] ;
int n ;
int dfs(int u)
{
int v ;
visx[u] = 1 ;
for(v = 1 ; v <= n ; v ++)
{
if(Lx[u] + Ly[v] == W[u][v] && !visy[v])
{
visy[v] = 1 ;
if(link[v] == -1 || dfs(link[v]))
{
link[v] = u ;
return 1 ;
}
}
}
return 0 ;
}
void init()
{
int i , j ;
int tMax ;
for(i = 1 ; i <= n ; i ++)
{
tMax = -1 ;
for(j = 1 ; j <= n ; j ++)
{
scanf("%d" , &W[i][j]) ;
if(tMax < W[i][j]) tMax = W[i][j] ;
}
Lx[i] = tMax ;
}
}
void solve()
{
mem(Ly , 0) ;
mem(link , -1) ;
int i ;
for(i = 1 ; i <= n ; i ++)
{
while (1)
{
mem(visx , 0) ;
mem(visy , 0) ;
if(dfs(i)) break ;
else
{
int d = INF ;
int k , j ;
for(k = 1 ; k <= n ; k ++)
{
if(visx[k])
{
for(j = 1 ; j <= n ; j ++)
{
if(!visy[j])
{
if(d > Lx[k] + Ly[j] - W[k][j])
{
d = Lx[k] + Ly[j] - W[k][j] ;
}
}
}
}
}
for(k = 1 ; k <= n ; k ++)
{
if(visx[k])
{
Lx[k] -= d ;
}
if(visy[k])
{
Ly[k] += d ;
}
}
}
}
}
int ans = 0 ;
for(i = 1 ; i <= n ; i ++)
{
ans += W[ link[i] ][i] ;
}
printf("%d\n" , ans) ;
}
int main()
{
while (scanf("%d" , &n) != EOF)
{
init() ;
solve() ;
}
return 0 ;
}
下面是O(n ^ 3) 的写法:
#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<cstdio>
#define mem(a ,b) memset(a , b , sizeof(a))
using namespace std ;
const int MAXN = 305 ;
const int INF = 0x7fffffff ;
int Lx[MAXN] , Ly[MAXN] , link[MAXN] ;
int visx[MAXN] , visy[MAXN] ;
int slack[MAXN] ;
int w[MAXN][MAXN] ;
int n ;
void init()
{
int i , j ;
int tMax = -1 ;
for(i = 1 ; i <= n ; i ++)
{
tMax = -1 ;
for(j = 1 ; j <= n ; j ++)
{
scanf("%d" , &w[i][j]) ;
if(tMax < w[i][j]) tMax = w[i][j] ;
}
Lx[i] = tMax ;
}
}
int dfs(int u)
{
int v ;
visx[u] = 1 ;
for(v = 1 ; v <= n ; v ++)
{
if(visy[v]) continue ;
int t = Lx[u] + Ly[v] - w[u][v] ;
if(t == 0)
{
visy[v] = 1 ;
if(link[v] == -1 || dfs(link[v]))
{
link[v] = u ;
return 1 ;
}
}
else
{
if(slack[v] > t)
slack[v] = t ;
}
}
return 0 ;
}
void solve()
{
mem(Ly , 0) ;
mem(link , -1) ;
int i ;
for(i = 1 ; i <= n ; i ++)
{
int j ;
for(j = 1 ; j <= n ; j ++)
slack[j] = INF ;
while (1)
{
mem(visx , 0) ;
mem(visy , 0) ;
if(dfs(i)) break ;
int d = INF ;
for(j = 1 ; j <= n ; j ++)
{
if(!visy[j] && d > slack[j])
d = slack[j] ;
}
for(j = 1 ; j <= n ; j ++)
{
if(visx[j]) Lx[j] -= d ;
if(visy[j]) Ly[j] += d ;
else slack[j] -= d ;
}
}
}
int ans = 0 ;
for(i = 1 ; i <= n ; i ++)
{
ans += w[ link[i] ][i] ;
}
printf("%d\n" , ans) ;
}
int main()
{
while (scanf("%d" , &n) != EOF)
{
init() ;
solve() ;
}
return 0 ;
}