关闭

codeforces575C——Party

标签: codeforces575CKM算法非递归
429人阅读 评论(0) 收藏 举报
分类:

1、题意:给出n个人,将n个人平均分配到n个岗位中,分别是星期五和星期六,每个人在每个岗位都有各自的兴奋程度,求最大的兴奋程度。
2、分析:这个题拿过来一看就是KM,所以先写一发,O(C(n,n/2)n3)的算法。
就是将人数用枚举分成两半。那么我们暴力KM做二分图最大权匹配。。
然而就是这个结果TLE的代码GG。。。然后我随手加了一个如果超过1.97s就退出的破东西,然后就A了,竟然A了。。。mdzz..

#include <ctime>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
using namespace std;
#define M 25
#define inf 214748364

inline int read(){
    char ch = getchar(); int x = 0, f = 1;
    while(ch < '0' || ch > '9'){
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while('0' <= ch && ch <= '9'){
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}

namespace KM{
    #define N 25
    int n;
    int A[N], B[N];
    int w[N][N];
    int link[N];
    bool vis[N], vis1[N];
    int mn; 
    int ans;

    inline void init(){
        memset(w, 0, sizeof(w));
    }

    inline bool find(int x){
        vis[x] = true;
        for(int i = 1; i <= n; i ++){
            if(vis1[i]) continue;
            int t = A[x] + B[i] - w[x][i];
            if(!t){
                vis1[i] = true; 
                if(!~link[i] || find(link[i])){
                    link[i] = x;
                    return true;
                }
            }
            else mn = min(mn, t);
        }
        return false;
    }

    inline void km(){
        memset(link, -1, sizeof(link));
        memset(A, 0, sizeof(A));
        memset(B, 0, sizeof(B));
        for(int i = 1; i <= n; i ++){
            for(int j = 1; j <= n; j ++){
                A[i] = max(A[i], w[i][j]);
            }
        }
        for(int i = 1; i <= n; i ++){
            while(1){
                memset(vis, 0, sizeof(vis));
                memset(vis1, 0, sizeof(vis1));
                mn = inf;
                if(find(i)) break;
                int res = 0;
                for(int j = 1; j <= n; j ++){
                    if(vis[j]) A[j] -= mn;
                    if(vis1[j]) B[j] += mn; 
                    res += A[j];
                    res += B[j];
                }
                if(res < ans) return;
            }
        }
    }

    inline int getsum(){
        int ret = 0;
        for(int i = 1; i <= n; i ++) ret += A[i], ret += B[i];
        return ret;
    }
} 

int a[M][M];
int b[M][M];

inline int num1(int x){
    int ret = 0;
    while(x){
        ret += (x & 1);
        x >>= 1;
    }
    return ret;
}

int main(){
    int start = clock();
    int n = read(); KM::n = n;
    for(int i = 1; i <= n; i ++){
        for(int j = 1; j <= n; j ++){
            a[i][j] = read();
        }
    }
    for(int i = 1; i <= n; i ++){
        for(int j = 1; j <= n; j ++){
            b[i][j] = read();
        }
    }
    int tot;
    for(int i = (1 << n) - 1; i >= 0; i --){
        if(num1(i) == (n >> 1)){
            KM::init();
            tot = 0;
            for(int j = 0; j < n; j ++){
                if(i & (1 << j)){
                    tot ++;
                    for(int k = 1; k <= n; k ++){
                        KM::w[tot][k] = a[j + 1][k];
                    }
                }
                else{
                    tot ++;
                    for(int k = 1; k <= n; k ++){
                        KM::w[tot][k] = b[j + 1][k];
                    }
                }
            }
            KM::km();
            int res = KM::getsum();
            KM::ans = max(KM::ans, res);
            if(clock() - start > 1970) break;
        }
    }
    printf("%d\n", KM::ans);
    return 0;
}

然而题目做成这个破样怎么能算已经AC。。实际上还是想明白还是很简单的,KM每次相当于新插入一个节点,然后做一遍增广路,那么我们在爆搜的时候,每选择一个节点就相当于在KM算法中加入一个节点,这样,当我爆搜枚举完的时候,KM算法已经算出了结果,时间复杂度C(n,n/2)n2,另外。。这个题递归版的KM会TLE。。。特地去学习了一下非递归版的KM,一会发学习报告。。贴上AC的代码

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
using namespace std;
#define M 25
#define inf 0x3f3f3f3f

inline int read(){
    char ch = getchar(); int x = 0, f = 1;
    while(ch < '0' || ch > '9'){
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while('0' <= ch && ch <= '9'){
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}

int n;
int A[M], B[M], lk[M], way[M], mn[M];
int w[M][M];
int a[M][M], b[M][M];
bool vis[M];
int ans; 

inline void km(int x){
    lk[0] = x;
    int j0 = 0;
    memset(vis, 0, sizeof(vis));
    memset(mn, 0x3f, sizeof(mn));
    do{
        vis[j0] = true;
        int i0 = lk[j0], num = inf, j1;
        for(int j = 1; j <= n; j ++){
            if(!vis[j]){
                int t = -w[i0][j] + A[i0] + B[j];
                if(t < mn[j]) mn[j] = t, way[j] = j0;
                if(mn[j] < num) num = mn[j], j1 = j;
            }
        }
        for(int j = 0; j <= n; j ++){
            if(vis[j]) A[lk[j]] -= num, B[j] += num;
            else mn[j] -= num;
        } 
        j0 = j1;
    } while(~lk[j0]);
    do{
        int j1 = way[j0];
        lk[j0] = lk[j1];
        j0 = j1;
    } while(j0);
}

inline void dfs(int x, int t){
    if(x == n + 1){
        ans = max(ans, -B[0]);
        return;
    }
    int oA[M], oB[M], m;
    if(t == (n >> 1)){
        for(int i = 1; i <= n; i ++) w[x][i] = a[x][i];
        km(x);
        dfs(x + 1, t);
        for(int i = 1; i <= n; i ++) w[x][i] = 0;
    }
    else if(x - t - 1 == (n >> 1)){
        for(int i = 1; i <= n; i ++) w[x][i] = b[x][i];
        km(x);
        dfs(x + 1, t + 1);
        for(int i = 1; i <= n; i ++) w[x][i] = 0; 
    }
    else{
        int oa[M], ob[M], olk[M], oway[M];
        memcpy(oa, A, sizeof(A));
        memcpy(ob, B, sizeof(B));
        memcpy(olk, lk, sizeof(lk));
        memcpy(oway, way, sizeof(way));
        for(int i = 1; i <= n; i ++) w[x][i] = a[x][i];
        km(x);
        dfs(x + 1, t);
        for(int i = 1; i <= n; i ++) w[x][i] = 0;
        memcpy(A, oa, sizeof(oa));
        memcpy(B, ob, sizeof(ob));
        memcpy(lk, olk, sizeof(olk));
        memcpy(way, oway, sizeof(oway));
        for(int i = 1; i <= n; i ++) w[x][i] = b[x][i];
        km(x);
        dfs(x + 1, t + 1);
        for(int i = 1; i <= n; i ++) w[x][i] = 0; 
    }
}

int main(){
    //freopen("0input.in", "r", stdin);
    n = read();
    for(int i = 1; i <= n; i ++){
        for(int j = 1; j <= n; j ++){
            a[i][j] = read();
        }
    }
    for(int i = 1; i <= n; i ++){
        for(int j = 1; j <= n; j ++){
            b[i][j] = read();
        }
    } 
    memset(lk, -1, sizeof(lk));
    dfs(1, 0);
    printf("%d\n", ans);
    return 0;
}
0
0

查看评论
* 以上用户言论只代表其个人观点,不代表CSDN网站的观点或立场
    个人资料
    • 访问:91484次
    • 积分:2606
    • 等级:
    • 排名:第14070名
    • 原创:174篇
    • 转载:1篇
    • 译文:0篇
    • 评论:23条
    博客专栏