描述
给定一个图,每个边是黑的或白的,求恰好有 k 条黑边的最小生成树。
分析
做法:
令所有黑边加上一个权值c,这样在 kruskal 算法中,由小到大选择边权时黑边会整体向前或者向后平移。通过二分找到一个 c 使得 kruskal 最后恰好用了 k 条黑边,答案就是 ans( 新图最小生成树的值) - k*c 。
注意点:
因为是整数二分,并且 kruskal 边排序时让黑边尽可能靠前,形式上最后可能不存在一个整数 c 使得恰好用了 k 条黑边(具体可见下面样例),这时是要找到一个最大的 c 使得用了大于等于k条黑边。因为 如果取 c 时用了 a (a<k) 条黑边, 取 c-1 时有 b (b>k)条黑边,增量c只减少了1,所以 c - 1时是存在一些黑边和白边权值相等的情况的,因为kruskal中尽量选了黑边,对于黑边和白边权值相等换掉黑边取白边,就可以恰好等于k了。
O ( m l o g c ) c O(mlogc) \ \ c O(mlogc) c是最大边权
代码
#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define pii pair<int,int>
#define pll pair<ll,ll>
#define pli pair<ll,int>
#define Min(a,b,c) min(a,min(b,c))
#define Max(a,b,c) max(a,max(b,c))
typedef long long ll;
typedef unsigned long long ull;
const double pi = 3.141592653589793;
const double eps = 1e-8;
const int INF = 0x3f3f3f3f;
const int N = 50010, M = 100010;
struct node
{
int x, y, z, tag;
bool operator < (const node & b) const
{
return z < b.z;
}
};
int n, m, k, cnta, cntb, ans, fa[N];
node a[M], b[M], edge[M];
int get(int x)
{
if (x == fa[x]) return x;
return fa[x] = get(fa[x]);
}
int check(int c)
{
for (int i = 1; i <= cnta; i++) a[i].z += c;
for (int i = 1, p = 1, q = 1; i <= m; i++)
{
if (a[p].z <= b[q].z)
edge[i] = a[p++];
else
edge[i] = b[q++];
}
for (int i = 0; i < n; i++) fa[i] = i;
int sum = 0;
ans = 0;
for (int i = 1; i <= m; i++)
{
int x = edge[i].x, y = edge[i].y, z = edge[i].z, tag = edge[i].tag;
int rx = get(x), ry = get(y);
if (rx == ry) continue;
fa[rx] = ry;
sum += (tag == 0);
ans += z;
}
for (int i = 1; i <= cnta; i++) a[i].z -= c;
return sum;
}
int main()
{
int tt = 0;
while(scanf("%d%d%d", &n, &m, &k) != EOF)
{
cnta = cntb = 0;
for (int i = 1; i <= m; i++)
{
int x, y, z, tag;
scanf("%d%d%d%d", &x, &y, &z, &tag);
if (!tag) a[++cnta] = {x, y, z, tag};
else b[++cntb] = {x, y, z, tag};
}
sort(a + 1, a + cnta + 1);
sort(b + 1, b + cntb + 1);
a[cnta + 1].z = INF, b[cntb + 1].z = INF;
/* 这里注释掉,这样写二分最终答案会错误,原因见上述分析
int l = -100, r = 100;
while (l < r)
{
int mid = l + r >> 1;
if (check(mid) <= k) r = mid;
else l = mid + 1;
}
check(l);
*/
int l = -100, r = 100;
while (l < r)
{
int mid = l + r + 1 >> 1;
if (check(mid) >= k) l = mid;
else r = mid - 1;
}
check(l);
printf("Case %d: %d\n", ++tt, ans - k * l);
}
return 0;
}
/*
3 5 1
2 0 7 1
1 0 13 0
0 1 13 0
0 1 6 1
2 0 14 0
*/