HDU - 2255 奔小康赚大钱 二分图最佳完美匹配——KM算法

题目链接

题意: n n n 个村民, n n n 个房间,第 i i i 个村民对应第 j j j 个房间都有个收益 w i , j w_{i,j} wi,j ,且一个房间只能给一个村民,现在给这些村民分配房间,求获得的最大的收益值。

思路:很明显这是个二分图,与普通二分图不同的是这个是边带权的二分图。看题意即为求该二分图的一个完备匹配,其所有匹配边的和在所有完备匹配中最大,即为求二分图最佳完美匹配。

KM算法(专门用来求解二分图最佳完美匹配):
首先我们先对每个点赋予权值,村民为 x x x 集合,房间为 y y y 集合,则权值分别为 w x i wx_i wxi w y j wy_j wyj
其满足 w e i g h t ( e d g e < i , j > ) ≤ w x i + w y j weight(edge<i,j>) \leq wx_i+wy_j weight(edge<i,j>)wxi+wyj 。初始时,其中一个集合权值全为零,另一个集合中一个点的权值为其所连边里最大边的权重。

据此,整个算法流程为:
1.初始化点权。
2.枚举一个集合 x x x 中的点。
3.对于当前枚举的点,在子图中寻找完备匹配
(子图中所有边 e d g e < i , j > edge<i,j> edge<i,j> 必须满足 w e i g h t ( e d g e < i , j > ) = w x i + w y j weight(edge<i,j>) = wx_i+wy_j weight(edge<i,j>)=wxi+wyj )。
4.若没有找到增广路,则修改部分点权。继续第3步。若找到则继续枚举下一个点直到枚举完。

对此,重点则在于如何修改点权:
每次修改为一个固定值 d d d ,集合 x x x 上位于匹配中的点的点权 + d +d +d ,集合 y y y 上位于匹配中的点的点权 − d -d d 。其余点的点权不变。
那么对于一条边 e d g e < i , j > edge<i,j> edge<i,j> 我们可以分类讨论:
i , j i,j ij均在增广路中, w x i − d + w y j + d wx_i-d+wy_j+d wxid+wyj+d 则点权和与边权关系不变,边与子图关系也不变。
i i i 在, j j j 不在, w x i − d + w y j wx_i-d+wy_j wxid+wyj 则点权和与边权关系有变,则边与子图关系可能变化。
i i i 不在, j j j 在, w x i + w y j + d wx_i+wy_j+d wxi+wyj+d 则若边原来不在子图的修改之后也不在。
综上 d d d 的取值取第二种边的 m i n ( w x i + w y j − w e i g h t ( e d g e < i , j > ) ) min(wx_i+wy_j-weight(edge<i,j>)) min(wxi+wyjweight(edge<i,j>))
所以每次修改点权都要花费 n 2 n^2 n2 的时间单独求一遍 d d d

整个算法时间复杂度就达到 n 4 n^4 n4

#include<cstdio>
#include<iostream>
#include<vector>
#include<cstring>
#include<string>
#include<queue>
#include<algorithm>
using namespace std;
#define NUM 305
#define mst(array,val,type,Count) memset(array,val,sizeof(type)*(Count))
class Graph
{
    public:
        int n, g[NUM][NUM];
        int wx[NUM], wy[NUM];
        bool build()
        {
            if (scanf("%d", &n) == EOF)return false;
            for (int i = 1; i <= n; ++i)
            {
                wx[i] = 0;
                for (int j = 1; j <= n; ++j)
                {
                    scanf("%d", &g[i][j]);
                    wx[i] = max(wx[i], g[i][j]);
                }
            }
            for (int i = 1; i <= n; ++i)wy[i] = 0;
            return true;
        }
};
class Kuhn_Munkres : public Graph
{
    private:
        int y_to[NUM];
        bool vis_x[NUM], vis_y[NUM];
        bool dfs(const int &vertex);
        int KM();
    public:
        void solve()
        {
            while(build())
                printf("%d\n", KM());
        }
} G;
bool Kuhn_Munkres::dfs(const int &vertex)
{
    vis_x[vertex] = true;
    int weight;
    for (int i = 1; i <= n; ++i)
    {
        if (vis_y[i])continue;
        weight = wx[vertex] + wy[i] - g[vertex][i];
        if (weight == 0)
        {
            vis_y[i] = true;
            if (y_to[i] == -1 || dfs(y_to[i]))
            {
                y_to[i] = vertex;
                return true;
            }
        }
    }
    return false;
}
int Kuhn_Munkres::KM()
{
    int d, ans = 0;
    memset(y_to, -1, sizeof(int) * (n + 1));
    for (int i = 1; i <= n; ++i)
    {
        mst(vis_x, false, bool, n + 1), mst(vis_y, false, bool, n + 1);
        while (!dfs(i))
        {
            d = -1;
            for (int j = 1; j <= n; ++j)
            {
                if (!vis_x[j])continue;
                for (int k = 1; k <= n; ++k)
                {
                    if (vis_y[j])continue;
                    d = (d == -1) ? (wx[j] + wy[k] - g[j][k]) : min(d, wx[j] + wy[k] - g[j][k]);
                }
            }
            for (int j = 1; j <= n; ++j)
            {
                if (vis_x[j])wx[j] -= d;
                if (vis_y[j])wy[j] += d;
            }
            mst(vis_x, false, bool, n + 1), mst(vis_y, false, bool, n + 1);
        }
    }
    for (int i = 1; i <= n; ++i)ans += wx[i];
    for (int i = 1; i <= n; ++i)ans += wy[i];
    return ans;
}
int main()
{
    G.solve();
    return 0;
}

优化:
KM算法可以优化至 n 3 n^3 n3 。这时需要用到一个 s l a c k slack slack 数组,用来存与集合 x x x 相对的集合 y y y 中每个节点对应的 d d d 值。所以每次修改的时候的 d d d 值即为 m i n ( s l a c k [ j ] ) min(slack[j]) min(slack[j])
注:每次修改点权时,若是不在增广路中的点的 s l a c k slack slack 也要 − d -d d
又注:该优化仅在随机数据下可以达到 n 3 n^3 n3 的复杂度,所以特地出数据卡还是能卡成 n 4 n^4 n4 的复杂度。

#include<cstdio>
#include<iostream>
#include<vector>
#include<cstring>
#include<string>
#include<queue>
#include<algorithm>
using namespace std;
#define NUM 305
#define INF 0x3f3f3f3f
#define mst(array,val,type,Count) memset(array,val,sizeof(type)*(Count))
class Graph
{
    public:
        int n, g[NUM][NUM];
        int wx[NUM], wy[NUM];
        bool build()
        {
            if (scanf("%d", &n) == EOF)return false;
            for (int i = 1; i <= n; ++i)
            {
                wx[i] = -INF;
                for (int j = 1; j <= n; ++j)
                {
                    scanf("%d", &g[i][j]);
                    wx[i] = max(wx[i], g[i][j]);
                }
            }
            for (int i = 1; i <= n; ++i)wy[i] = 0;
            return true;
        }
};
class Kuhn_Munkres : public Graph
{
    private:
      int y_to[NUM], slack[NUM];
      bool vis_x[NUM], vis_y[NUM];
      bool dfs(const int &vertex);
      int KM();
    public:
        void solve()
        {
            while(build())
                printf("%d\n", KM());
        }
} G;
bool Kuhn_Munkres::dfs(const int &vertex)
{
    vis_x[vertex] = true;
    int weight;
    for (int i = 1; i <= n; ++i)
    {
        if (vis_y[i])continue;
        weight = wx[vertex] + wy[i] - g[vertex][i];
        if (weight == 0)
        {
            vis_y[i] = true;
            if (y_to[i] == -1 || dfs(y_to[i]))
            {
                y_to[i] = vertex;
                return true;
            }
        }
        else
            slack[i] = min(slack[i], weight);
    }
    return false;
}
int Kuhn_Munkres::KM()
{
    int d, ans = 0;
    memset(y_to, -1, sizeof(int) * (n + 1));
    for (int i = 1; i <= n; ++i)
    {
        mst(slack, INF, int, n + 1), mst(vis_x, false, bool, n + 1), mst(vis_y, false, bool, n + 1);
        while (!dfs(i))
        {
            d = INF;
            for (int j = 1; j <= n; ++j)
            {
                if (vis_y[j])continue;
                d = min(d, slack[j]);
            }
            for (int j = 1; j <= n; ++j)
            {
                if (vis_x[j])wx[j] -= d;
                if (vis_y[j])wy[j] += d;
                else slack[j] -= d;
            }
            memset(vis_x, false, sizeof(bool) * (n + 1)), memset(vis_y, false, sizeof(bool) * (n + 1));
        }
    }
    for (int i = 1; i <= n; ++i)ans += wx[i];
    for (int i = 1; i <= n; ++i)ans += wy[i];
    return ans;
}
int main()
{
    G.solve();
    return 0;
}

最后,要想做到稳定的 n 3 n^3 n3 只需要把算法中的dfs改成bfs即可

#include<bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define LLINF 0x3f3f3f3f3f3f3f3f
#define lowbit(x) ((-x) & x)
#define ffor(i, d, u) for (int i = (d); i <= (u); ++i)
#define _ffor(i, u, d) for (int i = (u); i >= (d); --i)
#define mst(array, Num, Kind, Count) memset(array, Num, sizeof(Kind) * (Count))
#define mp(x, y) make_pair(x, y)
#define fi first
#define se second
#define N 305
#define M 1000005
typedef long long ll;
typedef double db;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
typedef pair<db, db> pdd;
const db PI = acos(-1);
const ll MO = 1e9 + 7;
const ll Inv2 = (MO + 1) / 2;
const bool debug = true;
template <typename T>
inline void read(T &x)
{
    x=0;char c;T t=1;while(((c=getchar())<'0'||c>'9')&&c!='-');
    if(c=='-'){t=-1;c=getchar();}do(x*=10)+=(c-'0');while((c=getchar())>='0'&&c<='9');x*=t;
}
template <typename T, typename... Args>
inline void read(T &x, Args &... args)
{
    read(x), read(args...);
}
template <typename T>
inline void write(T x)
{
    int len=0;char c[21];if(x<0)putchar('-'),x*=(-1);
    do{++len;c[len]=(x%10)+'0';}while(x/=10);_ffor(i,len,1)putchar(c[i]);
}
int n, pre[N], xto[N], yto[N];
ll wx[N], wy[N], slack[N], g[N][N];
bool visa[N], visb[N];
int q[N], head, tail;
bool check(int cur)
{
    visb[cur] = true;
    if (yto[cur])
    {
        if (!visa[yto[cur]])
            q[++tail] = yto[cur], visa[yto[cur]] = true;
        return false;
    }
    while (cur)
        swap(cur, xto[yto[cur] = pre[cur]]);
    return true;
}
void bfs(int s)
{
    mst(visa, false, bool, n + 1), mst(visb, false, bool, n + 1), mst(slack, 0x3f, ll, n + 1);
    head = 1, tail = 1, q[1] = s, visa[s] = true;
    while (true)
    {
        while (head <= tail)
        {
            int u = q[head++];
            ffor(i, 1, n)
            {
                if (visb[i])
                    continue;
                ll value = wx[u] + wy[i] - g[u][i];
                if (slack[i] >= value)
                {
                    slack[i] = value, pre[i] = u;
                    if (slack[i] == 0 && check(i))
                        return;
                }
            }
        }
        ll d = LLINF;
        ffor(i, 1, n) if (visb[i] == false && slack[i]) d = min(d, slack[i]);
        ffor(i, 1, n)
        {
            if (visa[i]) wx[i] -= d;
            if (visb[i]) wy[i] += d;
            else slack[i] -= d;
        }
        head = 1, tail = 0;
        ffor(i, 1, n) if (!visb[i] && slack[i] == 0 && check(i)) return;
    }
}
ll solve()
{
    mst(xto, 0, int, n + 1), mst(yto, 0, int, n + 1), mst(wy, 0, int, n + 1);
    ffor(i, 1, n)
    {
        wx[i] = -LLINF;
        ffor(j, 1, n) read(g[i][j]), wx[i] = max(wx[i], g[i][j]);
    }
    ffor(i, 1, n) bfs(i);
    ll ans = 0;
    ffor(i, 1, n) ans += g[yto[i]][i];
    return ans;
}
inline int ac()
{
    while (scanf("%d", &n) != EOF)
        write(solve()), putchar('\n');
    return 0;
}
int main()
{
    ac();
    return 0;
}

PS:KM算法应用于解决最佳完美匹配,若不是求完备匹配的最大权匹配,需要把不存在的边都添上,新添的边的边权赋为0即可。
而若是求最小权完备匹配,则把边权全部取相反数即可。最后求出来的结果再取反即可。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值