题意: n n n 个村民, n n n 个房间,第 i i i 个村民对应第 j j j 个房间都有个收益 w i , j w_{i,j} wi,j ,且一个房间只能给一个村民,现在给这些村民分配房间,求获得的最大的收益值。
思路:很明显这是个二分图,与普通二分图不同的是这个是边带权的二分图。看题意即为求该二分图的一个完备匹配,其所有匹配边的和在所有完备匹配中最大,即为求二分图最佳完美匹配。
KM算法(专门用来求解二分图最佳完美匹配):
首先我们先对每个点赋予权值,村民为
x
x
x 集合,房间为
y
y
y 集合,则权值分别为
w
x
i
wx_i
wxi 和
w
y
j
wy_j
wyj ,
其满足
w
e
i
g
h
t
(
e
d
g
e
<
i
,
j
>
)
≤
w
x
i
+
w
y
j
weight(edge<i,j>) \leq wx_i+wy_j
weight(edge<i,j>)≤wxi+wyj 。初始时,其中一个集合权值全为零,另一个集合中一个点的权值为其所连边里最大边的权重。
据此,整个算法流程为:
1.初始化点权。
2.枚举一个集合
x
x
x 中的点。
3.对于当前枚举的点,在子图中寻找完备匹配
(子图中所有边
e
d
g
e
<
i
,
j
>
edge<i,j>
edge<i,j> 必须满足
w
e
i
g
h
t
(
e
d
g
e
<
i
,
j
>
)
=
w
x
i
+
w
y
j
weight(edge<i,j>) = wx_i+wy_j
weight(edge<i,j>)=wxi+wyj )。
4.若没有找到增广路,则修改部分点权。继续第3步。若找到则继续枚举下一个点直到枚举完。
对此,重点则在于如何修改点权:
每次修改为一个固定值
d
d
d ,集合
x
x
x 上位于匹配中的点的点权
+
d
+d
+d ,集合
y
y
y 上位于匹配中的点的点权
−
d
-d
−d 。其余点的点权不变。
那么对于一条边
e
d
g
e
<
i
,
j
>
edge<i,j>
edge<i,j> 我们可以分类讨论:
i
,
j
i,j
i,j均在增广路中,
w
x
i
−
d
+
w
y
j
+
d
wx_i-d+wy_j+d
wxi−d+wyj+d 则点权和与边权关系不变,边与子图关系也不变。
i
i
i 在,
j
j
j 不在,
w
x
i
−
d
+
w
y
j
wx_i-d+wy_j
wxi−d+wyj 则点权和与边权关系有变,则边与子图关系可能变化。
i
i
i 不在,
j
j
j 在,
w
x
i
+
w
y
j
+
d
wx_i+wy_j+d
wxi+wyj+d 则若边原来不在子图的修改之后也不在。
综上
d
d
d 的取值取第二种边的
m
i
n
(
w
x
i
+
w
y
j
−
w
e
i
g
h
t
(
e
d
g
e
<
i
,
j
>
)
)
min(wx_i+wy_j-weight(edge<i,j>))
min(wxi+wyj−weight(edge<i,j>))。
所以每次修改点权都要花费
n
2
n^2
n2 的时间单独求一遍
d
d
d。
整个算法时间复杂度就达到 n 4 n^4 n4 。
#include<cstdio>
#include<iostream>
#include<vector>
#include<cstring>
#include<string>
#include<queue>
#include<algorithm>
using namespace std;
#define NUM 305
#define mst(array,val,type,Count) memset(array,val,sizeof(type)*(Count))
class Graph
{
public:
int n, g[NUM][NUM];
int wx[NUM], wy[NUM];
bool build()
{
if (scanf("%d", &n) == EOF)return false;
for (int i = 1; i <= n; ++i)
{
wx[i] = 0;
for (int j = 1; j <= n; ++j)
{
scanf("%d", &g[i][j]);
wx[i] = max(wx[i], g[i][j]);
}
}
for (int i = 1; i <= n; ++i)wy[i] = 0;
return true;
}
};
class Kuhn_Munkres : public Graph
{
private:
int y_to[NUM];
bool vis_x[NUM], vis_y[NUM];
bool dfs(const int &vertex);
int KM();
public:
void solve()
{
while(build())
printf("%d\n", KM());
}
} G;
bool Kuhn_Munkres::dfs(const int &vertex)
{
vis_x[vertex] = true;
int weight;
for (int i = 1; i <= n; ++i)
{
if (vis_y[i])continue;
weight = wx[vertex] + wy[i] - g[vertex][i];
if (weight == 0)
{
vis_y[i] = true;
if (y_to[i] == -1 || dfs(y_to[i]))
{
y_to[i] = vertex;
return true;
}
}
}
return false;
}
int Kuhn_Munkres::KM()
{
int d, ans = 0;
memset(y_to, -1, sizeof(int) * (n + 1));
for (int i = 1; i <= n; ++i)
{
mst(vis_x, false, bool, n + 1), mst(vis_y, false, bool, n + 1);
while (!dfs(i))
{
d = -1;
for (int j = 1; j <= n; ++j)
{
if (!vis_x[j])continue;
for (int k = 1; k <= n; ++k)
{
if (vis_y[j])continue;
d = (d == -1) ? (wx[j] + wy[k] - g[j][k]) : min(d, wx[j] + wy[k] - g[j][k]);
}
}
for (int j = 1; j <= n; ++j)
{
if (vis_x[j])wx[j] -= d;
if (vis_y[j])wy[j] += d;
}
mst(vis_x, false, bool, n + 1), mst(vis_y, false, bool, n + 1);
}
}
for (int i = 1; i <= n; ++i)ans += wx[i];
for (int i = 1; i <= n; ++i)ans += wy[i];
return ans;
}
int main()
{
G.solve();
return 0;
}
优化:
KM算法可以优化至
n
3
n^3
n3 。这时需要用到一个
s
l
a
c
k
slack
slack 数组,用来存与集合
x
x
x 相对的集合
y
y
y 中每个节点对应的
d
d
d 值。所以每次修改的时候的
d
d
d 值即为
m
i
n
(
s
l
a
c
k
[
j
]
)
min(slack[j])
min(slack[j]) 。
注:每次修改点权时,若是不在增广路中的点的
s
l
a
c
k
slack
slack 也要
−
d
-d
−d。
又注:该优化仅在随机数据下可以达到
n
3
n^3
n3 的复杂度,所以特地出数据卡还是能卡成
n
4
n^4
n4 的复杂度。
#include<cstdio>
#include<iostream>
#include<vector>
#include<cstring>
#include<string>
#include<queue>
#include<algorithm>
using namespace std;
#define NUM 305
#define INF 0x3f3f3f3f
#define mst(array,val,type,Count) memset(array,val,sizeof(type)*(Count))
class Graph
{
public:
int n, g[NUM][NUM];
int wx[NUM], wy[NUM];
bool build()
{
if (scanf("%d", &n) == EOF)return false;
for (int i = 1; i <= n; ++i)
{
wx[i] = -INF;
for (int j = 1; j <= n; ++j)
{
scanf("%d", &g[i][j]);
wx[i] = max(wx[i], g[i][j]);
}
}
for (int i = 1; i <= n; ++i)wy[i] = 0;
return true;
}
};
class Kuhn_Munkres : public Graph
{
private:
int y_to[NUM], slack[NUM];
bool vis_x[NUM], vis_y[NUM];
bool dfs(const int &vertex);
int KM();
public:
void solve()
{
while(build())
printf("%d\n", KM());
}
} G;
bool Kuhn_Munkres::dfs(const int &vertex)
{
vis_x[vertex] = true;
int weight;
for (int i = 1; i <= n; ++i)
{
if (vis_y[i])continue;
weight = wx[vertex] + wy[i] - g[vertex][i];
if (weight == 0)
{
vis_y[i] = true;
if (y_to[i] == -1 || dfs(y_to[i]))
{
y_to[i] = vertex;
return true;
}
}
else
slack[i] = min(slack[i], weight);
}
return false;
}
int Kuhn_Munkres::KM()
{
int d, ans = 0;
memset(y_to, -1, sizeof(int) * (n + 1));
for (int i = 1; i <= n; ++i)
{
mst(slack, INF, int, n + 1), mst(vis_x, false, bool, n + 1), mst(vis_y, false, bool, n + 1);
while (!dfs(i))
{
d = INF;
for (int j = 1; j <= n; ++j)
{
if (vis_y[j])continue;
d = min(d, slack[j]);
}
for (int j = 1; j <= n; ++j)
{
if (vis_x[j])wx[j] -= d;
if (vis_y[j])wy[j] += d;
else slack[j] -= d;
}
memset(vis_x, false, sizeof(bool) * (n + 1)), memset(vis_y, false, sizeof(bool) * (n + 1));
}
}
for (int i = 1; i <= n; ++i)ans += wx[i];
for (int i = 1; i <= n; ++i)ans += wy[i];
return ans;
}
int main()
{
G.solve();
return 0;
}
最后,要想做到稳定的 n 3 n^3 n3 只需要把算法中的dfs改成bfs即可
#include<bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define LLINF 0x3f3f3f3f3f3f3f3f
#define lowbit(x) ((-x) & x)
#define ffor(i, d, u) for (int i = (d); i <= (u); ++i)
#define _ffor(i, u, d) for (int i = (u); i >= (d); --i)
#define mst(array, Num, Kind, Count) memset(array, Num, sizeof(Kind) * (Count))
#define mp(x, y) make_pair(x, y)
#define fi first
#define se second
#define N 305
#define M 1000005
typedef long long ll;
typedef double db;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
typedef pair<db, db> pdd;
const db PI = acos(-1);
const ll MO = 1e9 + 7;
const ll Inv2 = (MO + 1) / 2;
const bool debug = true;
template <typename T>
inline void read(T &x)
{
x=0;char c;T t=1;while(((c=getchar())<'0'||c>'9')&&c!='-');
if(c=='-'){t=-1;c=getchar();}do(x*=10)+=(c-'0');while((c=getchar())>='0'&&c<='9');x*=t;
}
template <typename T, typename... Args>
inline void read(T &x, Args &... args)
{
read(x), read(args...);
}
template <typename T>
inline void write(T x)
{
int len=0;char c[21];if(x<0)putchar('-'),x*=(-1);
do{++len;c[len]=(x%10)+'0';}while(x/=10);_ffor(i,len,1)putchar(c[i]);
}
int n, pre[N], xto[N], yto[N];
ll wx[N], wy[N], slack[N], g[N][N];
bool visa[N], visb[N];
int q[N], head, tail;
bool check(int cur)
{
visb[cur] = true;
if (yto[cur])
{
if (!visa[yto[cur]])
q[++tail] = yto[cur], visa[yto[cur]] = true;
return false;
}
while (cur)
swap(cur, xto[yto[cur] = pre[cur]]);
return true;
}
void bfs(int s)
{
mst(visa, false, bool, n + 1), mst(visb, false, bool, n + 1), mst(slack, 0x3f, ll, n + 1);
head = 1, tail = 1, q[1] = s, visa[s] = true;
while (true)
{
while (head <= tail)
{
int u = q[head++];
ffor(i, 1, n)
{
if (visb[i])
continue;
ll value = wx[u] + wy[i] - g[u][i];
if (slack[i] >= value)
{
slack[i] = value, pre[i] = u;
if (slack[i] == 0 && check(i))
return;
}
}
}
ll d = LLINF;
ffor(i, 1, n) if (visb[i] == false && slack[i]) d = min(d, slack[i]);
ffor(i, 1, n)
{
if (visa[i]) wx[i] -= d;
if (visb[i]) wy[i] += d;
else slack[i] -= d;
}
head = 1, tail = 0;
ffor(i, 1, n) if (!visb[i] && slack[i] == 0 && check(i)) return;
}
}
ll solve()
{
mst(xto, 0, int, n + 1), mst(yto, 0, int, n + 1), mst(wy, 0, int, n + 1);
ffor(i, 1, n)
{
wx[i] = -LLINF;
ffor(j, 1, n) read(g[i][j]), wx[i] = max(wx[i], g[i][j]);
}
ffor(i, 1, n) bfs(i);
ll ans = 0;
ffor(i, 1, n) ans += g[yto[i]][i];
return ans;
}
inline int ac()
{
while (scanf("%d", &n) != EOF)
write(solve()), putchar('\n');
return 0;
}
int main()
{
ac();
return 0;
}
PS:KM算法应用于解决最佳完美匹配,若不是求完备匹配的最大权匹配,需要把不存在的边都添上,新添的边的边权赋为0即可。
而若是求最小权完备匹配,则把边权全部取相反数即可。最后求出来的结果再取反即可。