匈牙利算法
问题
解决二分图最大匹配问题。
思想
不断寻找原有匹配
M
M
M 的增广路径以增大匹配数。
对每次匹配左边点
x
x
x,不断通过右边点
y
y
y 跳到左边
y
y
y 的对象点
x
∗
x*
x∗,直至
x
x
x 右边点
y
y
y 没有对象。
代码
#include<bits/stdc++.h>
#define pb push_back
using namespace std;
const int N=509;
int n,m,e,ans;
int vis[N],t[N];
vector<int>v[N];
bool dfs(int x)
{
for(auto y:v[x])
if(!vis[y])
{
vis[y]=1;
if(!t[y]||dfs(t[y]))
{
t[y]=x;
return 1;
}
}
return 0;
}
int main()
{
cin>>n>>m>>e;
for(int i=1;i<=e;i++)
{
int x,y;
cin>>x>>y;
if(x>n||y>m) continue;
v[x].pb(y);
}
for(int i=1;i<=n;i++)
{
memset(vis,0,sizeof(vis));
if(dfs(i)) ans++;
}
cout<<ans<<endl;
return 0;
}
KM算法
问题
求二分图最大权完美匹配。
步骤
-
初始化可行顶标的值 (设定 l x , l y lx,ly lx,ly 的初始值)
l x i = m a x ( w [ i ] [ j ] ) lx_i=max(w[i][j]) lxi=max(w[i][j]), l y i = 0 ly_i=0 lyi=0 -
用匈牙利算法寻找相等子图的完备匹配
完备匹配满足 l x [ x ] + l y [ y ] = w [ x ] [ y ] lx[x]+ly_[y]=w[x][y] lx[x]+ly[y]=w[x][y] -
若未找到增广路则修改可行顶标的值
对路径上的左节点 l x [ i ] − = m i n n lx[i]-=minn lx[i]−=minn ,对路径上的右节点 l y [ i ] + = m i n n ly[i]+=minn ly[i]+=minn
而 m i n n minn minn 为路径上左节点与非路径上右节点 l x [ x ] + l y [ y ] − w [ x ] [ y ] lx[x]+ly_[y]-w[x][y] lx[x]+ly[y]−w[x][y] 的最小值 -
重复(2)(3)直到找到相等子图的完备匹配为止
个人理解
初始顶标和可能是最大权,在之后每次匹配下会不断发生冲突,因此需解决在已有匹配下调整匹配使得匹配最优,所用思想即为贪心
每次匹配下,若有冲突,路径上左节点有 m m m 个,右节点 m − 1 m-1 m−1 个,最优解决方法即为:对路径上 m m m 个左节点 − m i n n -minn −minn ,对路径上 m − 1 m-1 m−1 个右节点 + m i n n +minn +minn 。
这样,相当于左顶标流向右顶标,保证了路径上 l x [ x ] + l y [ y ] = w [ x ] [ y ] lx[x]+ly_[y]=w[x][y] lx[x]+ly[y]=w[x][y] ,也使得路径上某个左节点(取 m i n n minn minn 的节点)可以与非路径上右节点(取 m i n n minn minn )匹配。同时,总的顶标和仅减小一个 m i n n minn minn,权值也最优。
博文推荐
KM算法详解+模板 -本文没有给出KM算法的原理,只是模拟了一遍算法的过程。明白易懂算法流程
二分图的最佳完美匹配——KM算法 -详细证明
代码
#include<bits/stdc++.h>
using namespace std;
const int N=29;
int n,minn,ans;
int p[N][N],q[N][N],lx[N],ly[N],visx[N],visy[N],t[N];
bool dfs(int x)
{
visx[x]=1;
for(int y=1;y<=n;y++)
if(!visy[y])
{
int temp=lx[x]+ly[y]-p[x][y];
if(!temp)
{
visy[y]=1;
if(!t[y]||dfs(t[y]))
{
t[y]=x;
return true;
}
}
else minn=min(minn,temp);
}
return false;
}
void KM()
{
for(int i=1;i<=n;i++)
while(1)
{
minn=1e9+7;
memset(visx,0,sizeof(visx));
memset(visy,0,sizeof(visy));
if(dfs(i)) break;
for(int j=1;j<=n;j++)
{
if(visx[j]) lx[j]-=minn;
if(visy[j]) ly[j]+=minn;
}
}
}
int main()
{
cin>>n;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
cin>>p[i][j];
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
{
cin>>q[i][j];
p[j][i]*=q[i][j];
lx[j]=max(lx[j],p[j][i]);
}
KM();
for(int i=1;i<=n;i++)
ans+=p[t[i]][i];
cout<<ans<<endl;
return 0;
}
优化
每次 d f s dfs dfs 都是从头开始的,重复较多,因此引入 s l a c k [ j ] slack[j] slack[j] 数组保存跟右部节点 j j j 相连的节点 i i i 的 l x [ i ] + l y [ j ] − w [ i ] [ j ] lx[i]+ly[j]-w[i][j] lx[i]+ly[j]−w[i][j] 的最小值
每次 b f s bfs bfs 时更新 s l a c k [ i ] slack[i] slack[i] 的值,为 0 0 0 的节点若无对象直接更新路径返回,有对象则将对象加入队列中不断更新 s l a c k slack slack 值
寻找 m i n n minn minn 时直接从非路径上节点寻找 m i n min min ,路径上 l x [ i ] − = m i n n , l y [ i ] + = m i n n , lx[i]-=minn,ly[i]+=minn, lx[i]−=minn,ly[i]+=minn,,非路径上 s l a c k [ i ] − = m i n n slack[i]-=minn slack[i]−=minn
将非路径上 s l a c k [ i ] = 0 slack[i]=0 slack[i]=0 (可能在之后形成增广路径)的节点的 p y [ i ] py[i] py[i] (右节点的对象,同时为路径上的节点)加入队列
不断重复以上操作,期间 v i s x , v i s y visx,visy visx,visy 不重新赋 0 0 0
时间复杂度 O ( n 3 ) O(n^3) O(n3)
代码
#include<bits/stdc++.h>
#define LL long long
#define pb push_back
using namespace std;
const int N=509;
int n,m;
LL minn,ans;
int visx[N],visy[N],px[N],py[N],pre[N];
LL lx[N],ly[N],e[N][N],slack[N];
queue<int>q;
void aug(int y)
{
while(y)
{
int ny=px[pre[y]];
px[pre[y]]=y;
py[y]=pre[y];
y=ny;
}
}
void bfs(int x)
{
while(q.size()) q.pop();
q.push(x);
while(1)
{
while(q.size())
{
int u=q.front(); q.pop();
visx[u]=1;
for(int i=1;i<=n;i++)
if(!visy[i])
{
if(lx[u]+ly[i]-e[u][i]<slack[i])
{
slack[i]=lx[u]+ly[i]-e[u][i];
pre[i]=u;
if(!slack[i])
{
visy[i]=1;
if(!py[i]) {aug(i); return;}
q.push(py[i]);
}
}
}
}
LL d=1e18+9;
for(int i=1;i<=n;i++)
if(!visy[i])
d=min(d,slack[i]);
for(int i=1;i<=n;i++)
{
if(visx[i]) lx[i]-=d;
if(visy[i]) ly[i]+=d;
else slack[i]-=d;
}
for(int i=1;i<=n;i++)
if(!visy[i]&&!slack[i])
{
visy[i]=1;
if(!py[i]) {aug(i); return;}
q.push(py[i]);
}
}
}
void KM()
{
for(int i=1;i<=n;i++)
{
memset(visx,0,sizeof(visx));
memset(visy,0,sizeof(visy));
memset(slack,127,sizeof(slack));
bfs(i);
}
}
int main()
{
scanf("%d%d",&n,&m);
memset(e,128,sizeof(e));
for(int i=1;i<=m;i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
e[x][y]=z;
}
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
lx[i]=max(lx[i],e[i][j]);
KM();
for(int i=1;i<=n;i++)
ans+=e[py[i]][i];
printf("%lld\n",ans);
for(int i=1;i<=n;i++)
printf("%d ",py[i]);
return 0;
}