http://hi.baidu.com/roba/blog/item/c842fdfac10d24dcb48f31d7.html
roba的解说
题目链接
题目大意就是说,给N个点,每个点都有一个权值,某些点之间存在有向边,有向边的权值为它的两端点权值之和。现在要求从中选出K条边,使得没有任何两条边共头或共尾,问K条边的权值和最小为多少。数据范围大的吓人,N<=10^4,边数M<=10^6,但是K<=100。 容易想到一个网络流(or 匹配?)的做法,就是把每个点p拆成两个(p,p'),然后对于原图中的边(u,v),在新图中加边(u,v'),权值不变。这样问题就转化为在新的二分图中求出一个包含K条边的最小带权匹配。这个问题可以转化成最小费用流解决:添加源和汇,源到所有左子集的点添加容量为1,费用为0的边;所有右子集到汇添加容量为1,费用为0的边;中间的边容量为1,费用为原权值。然后求一个流值为K的最小费用流。 然而常用的最小费用流算法复杂度太高,比如用Bellman-Ford找最小费用路,那么复杂度就高达O(MNK),显然不可接受。到网上四处去问, Jin Bin大牛给了一个 ahyangyi的解答,是贪心一个初始流,先选出尽可能小的K条边,另外 Adrian Kuegel给了一个 非常漂亮的解法,简述如下: 朴素的费用流算法没有利用到题目中“边的权值是它的两端点权值之和”这一条件。考虑一下类似匈牙利算法找交错链的过程,实际上,在这个条件下,可以发现每次增广后增加的费用就是新加入匹配的两个点的权值。比如当前的匹配是(b,a),现在发现一条交错链(c,a),(a,b),(b,d),则增广后新的匹配变为(c,a),(b,d),增加的权值就是c点和d点的权值。由于这个特点,我们可以在每次增广的时候,枚举左子集中当前未匹配的点作为起点,用BFS或DFS查找交错链,记录下每条交错链的起点和终点的权值和,在所有这些权值和中取最小的一对进行增广。注意如果把点按权值排序后,每个点就只须访问一次,故每次增广的时间复杂度就是遍历图的O(M+N),共需增广K次,故总的复杂度为O((M+N)*K)。 |
OK,然后是代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <string>
#include <cstdio>
#include <cmath>
#include <queue>
#include <map>
#include <set>
#define eps 1e-5
#define MAXN 11111
#define MAXM 1111111
#define INF 1000000007
using namespace std;
int k, n, m;
struct P
{
int w, id;
bool operator <(const P &a)const
{
return w < a.w;
}
}p[MAXN];
struct EDGE
{
int v, next;
}edge[MAXM];
int head[MAXN], e;
int id[MAXN], used[MAXN], lx[MAXN], ly[MAXN], fa[MAXN], ha[MAXN];
int ans;
void init()
{
e = 0;
ans = 0;
memset(head, -1, sizeof(head));
memset(used, 0, sizeof(used));
memset(lx, 0, sizeof(lx));
memset(ly, 0, sizeof(ly));
}
void add(int x, int y)
{
edge[e].v = y;
edge[e].next = head[x];
head[x] = e++;
}
void dfs(int u, int f)
{
for(int i = head[u]; i != -1; i = edge[i].next)
{
int v = edge[i].v;
if(!fa[v])
{
fa[v] = u;
if(!ly[v]) ha[v] = f;
else dfs(ly[v], f);
}
}
}
bool find()
{
memset(fa, 0, sizeof(fa));
for(int i = 1; i <= n; i++)
if(!used[i]) dfs(i, i);
int mi = INF, pos = -1;
for(int i = 1; i <= n; i++)
if(!ly[i] && fa[i])
if(mi > p[i].w + p[ha[i]].w)
{
mi = p[i].w + p[ha[i]].w;
pos = i;
}
if(pos == -1) return 0;
ans += mi;
int u;
for(u = pos; fa[u] != ha[pos];)
{
ly[u] = fa[u];
int tmp = lx[fa[u]];
lx[fa[u]] = u;
u = tmp;
}
ly[u] = ha[pos];
lx[ha[pos]] = u;
used[ha[pos]] = 1;
return 1;
}
int main()
{
int T;
scanf("%d", &T);
while(T--)
{
scanf("%d%d%d", &k, &n, &m);
for(int i = 1; i <= n; i++) scanf("%d", &p[i].w), p[i].id = i;
sort(p + 1, p + n + 1);
for(int i = 1; i <= n; i++) id[p[i].id] = i;
int u, v;
init();
for(int i = 1; i <= m; i++)
{
scanf("%d%d", &u, &v);
add(id[u], id[v]);
}
int flag = 1;
for(int i = 1; i <= k; i++)
if(!find()){flag = 0; break;}
if(!flag) printf("NONE\n");
else printf("%d\n", ans);
}
return 0;
}