二分图的带权最大匹配
KM 算法只能在满足 带权最大匹配一定是完备匹配 的图中正确求解。
交错树:在匈牙利算法中,如果从某个左部节点出发寻找匹配失败,那么在 DFS 的过程中,所有访问过的节点(若干条路径),以及为了访问这些节点而经过的边,共同构成一棵树。这棵树被成为交错树。
顶标:如果任意 i , j i,j i,j 满足 A i + B j ≥ w i , j A_i+B_j\ge w_{i,j} Ai+Bj≥wi,j ,则把这些整数值 A i , B j A_i,B_j Ai,Bj 成为节点的顶标。
相等子图: 二分图中所有节点和满足 A i + B j = w i , j A_i+B_j=w_{i,j} Ai+Bj=wi,j 的边构成的子图。
定理:若相等子图中存在完备匹配,则这个完备匹配就是二分图的带权最大匹配。
我们的目标是通过不断调整可行顶标,使得相等子图是完美匹配。
首先初始化一组可行顶标:
A i = max 1 ≤ j ≤ n { w ( i , j ) } , B i = 0 A_i=\max_{1\le j\le n}\{w(i,j)\},B_i=0 Ai=max1≤j≤n{w(i,j)},Bi=0
然后选一个未匹配点,求增广路。如果找到增广路就增广,否则,会得到一个 交错树 。
令 S , T S,T S,T为二分图左边右边在交错树上的点, S ′ , T ′ S',T' S′,T′表示不再交错树上的点
考虑选择 △ = min { A i + B j − w ( i , j ) ∣ i ∈ S , j ∈ T ′ } \triangle=\min\{A_i+B_j-w(i,j)|i\in S,j\in T'\} △=min{Ai+Bj−w(i,j)∣i∈S,j∈T′},然后让 S S S中的顶标减去 △ \triangle △, T T T中的顶标加上 △ \triangle △,可以发现此时至少有一条新的边会加入相等子图(原来在交错树上的边不变),设这个点为 v v v,则:
- v v v是未匹配点,则找到增广路
- v v v和 S ′ S' S′中的点已经匹配,则将 v v v和这个点加入交错树
这样至多 n n n轮后就可以找到增广路。
复杂度 O ( N 3 ) O(N^3) O(N3) 。
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=505;
int n,m,g[N][N];
int sla[N],vis[N],la[N],lb[N],match[N],ans[N],pre[N];
inline int read() {
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9') {
if(c=='-') {
f=-1;
}
c=getchar();
}
while(c>='0'&&c<='9') {
x=(x<<1)+(x<<3)+c-'0';
c=getchar();
}
return x*f;
}
void bfs(int u) {
memset(vis,0,sizeof vis);
memset(sla,0x3f,sizeof sla);
int pos=0,x,p,delta;
for(match[pos]=u;match[pos];pos=p) {
vis[pos]=1,x=match[pos],delta=0x3f3f3f3f;
for(int y=1;y<=n;y++) {
if(vis[y]) {
continue;
}
if(la[x]+lb[y]-g[x][y]<sla[y]) {
sla[y]=la[x]+lb[y]-g[x][y];
pre[y]=pos;
}
if(sla[y]<delta) {
delta=sla[y];
p=y;
}
}
for(int y=0;y<=n;y++) {
if(vis[y]) {
la[match[y]]-=delta;
lb[y]+=delta;
}
else {
sla[y]-=delta;
}
}
}
for(;pos;pos=pre[pos]) match[pos]=match[pre[pos]];
}
int KM() {
memset(la,-0x3f,sizeof la);
for(int i=1;i<=n;i++) {
for(int j=1;j<=n;j++) {
la[i]=max(la[i],g[i][j]);
}
}
for(int i=1;i<=n;i++) {
bfs(i);
}
int tot=0;
for(int i=1;i<=n;i++) {
tot+=la[i]+lb[i];
}
return tot;
}
signed main() {
n=read(),m=read();
memset(g,-0x3f,sizeof g);
for(int i=1;i<=m;i++) {
int u=read(),v=read(),w=read();
g[u][v]=w;
}
printf("%lld\n",KM());
for(int i=1;i<=n;i++) {
printf("%lld ",match[i]);
}
}
upd on 2023/10/26 :修改了代码和一些不知所云的内容。