以前写过带权并查集的题目,自己琢磨半天判断了很多情况,最后终于能过题了,但是写法之麻烦实在是没法保证再写一次也不会出错。看网上的博客才发现能够有简单许多的方法来搞定带权并查集。
首先就是获取祖先的函数,这个函数中与普通的并查集唯一的不同之处就是更新 dis 数组的部分,dis 中存的为该点到其祖先的距离。
int get_root(int x)
{
if(father[x] == x) return x;
int root = get_root(father[x]);
//回溯时将原来的父亲到祖先的距离加上,那么就可以得到该点到祖先的距离
dis[x] = dis[x] + dis[father[x]];
return father[x] = root; //路径压缩
}
那么接下来就是重要的合并以及判断冲突的部分了。
a → ra
↓x ↓y
b → rb
如图,ra 为 a 的祖先,rb 为 b 的祖先,那么 a 到 ra 的距离为 dis[a],b 到 rb 的距离为 dis[b]。
如果 ra 不等于 rb ,则表示 a 与 b 的距离是不确定的,需要通过给定的距离来建立 a 与 b 之间的关系,现在得知 a 与 b 之间的距离是 x ,由向量中的知识我们可以得到,ra 与 rb 之间的距离 y = -dis[a] + x + dis[b],然后将 a 与 b 合并到一个集合,即 father[ra] = rb,这样 a 与 b 之间的关系, a 与(b所在的集合中的其他点)之间的关系以及 b 与(a所在的集合中的其他点)之间的关系就建立完成了。
a → ra == rb
↓x ↗
b
如果 ra 等于 rb,那么就表示 a 与 b 之间的关系已经确定了,这时就需要判断是否有冲突了。现在得知 a 与 b 之间的距离为 x,那么只需要判断 x 等不等于 dis[a] - dis[b] 即可,若等于则皆大欢喜,若不等于则必然是产生冲突了。
一道模板题:HDU-3047
代码:(输入不能cin,会超时...醉了)
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <string>
#include <queue>
#include <stack>
#include <vector>
#include <map>
#include <set>
#include <bitset>
using namespace std;
typedef long long ll;
#define int ll
#define INF 0x3f3f3f3f3f3f3f3f
#define MAXM 100000 + 10
#define MAXN 50000 + 10
const ll mod = 1e9 + 7;
#define P pair<int, int>
#define fir first
#define sec second
int n, m;
int father[MAXN], dis[MAXN];
int ans;
int get_root(int x)
{
if(father[x] == x) return x;
int root = get_root(father[x]);
dis[x] += dis[father[x]];
return father[x] = root;
}
signed main()
{
while(cin >> n >> m) {
memset(dis, 0, sizeof(dis));
ans = 0;
for(int i = 0; i < MAXN; i ++) father[i] = i;
for(int i = 0; i < m; i ++) {
int a, b, x; //cin >> a >> b >> x;
scanf("%lld %lld %lld", &a, &b, &x);
int ra = get_root(a);
int rb = get_root(b);
if(ra == rb) {
if(dis[a] - dis[b] != x)
ans ++;
}
else {
father[ra] = rb;
dis[ra] = -dis[a] + x + dis[b];
}
}
cout << ans << endl;
}
}
/*
The WAM is F**KING interesting .
*/
同样,带权并查集也可以解决种类并查集的问题,在距离数组中取个mod就好了,那就省去了开几倍数组的麻烦,而且还能解决种类数很多的情况。
经典种类并查集:食物链
代码:
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <string>
#include <queue>
#include <stack>
#include <vector>
#include <map>
#include <set>
#include <bitset>
using namespace std;
typedef long long ll;
#define int ll
#define INF 0x3f3f3f3f3f3f3f3f
#define MAXM 100000 + 10
#define MAXN 50000 + 10
const ll mod = 1e9 + 7;
#define P pair<int, int>
#define fir first
#define sec second
int n, m;
int father[MAXN], kind[MAXN];
int ans;
int get_root(int x)
{
if(father[x] == x) return x;
int root = get_root(father[x]);
kind[x] = (kind[x] + kind[father[x]]) % 3;
return father[x] = root;
}
signed main()
{
cin >> n >> m;
for(int i = 1; i <= n; i ++) father[i] = i;
for(int i = 0; i < m; i ++) {
int x, a, b; cin >> x >> a >> b;
x --;
if(a > n || b > n) {
ans ++;
continue;
}
int ra = get_root(a);
int rb = get_root(b);
if(ra == rb) {
if(x != (kind[a] - kind[b] + 3) % 3)
ans ++;
}
else {
father[ra] = rb;
kind[ra] = (kind[b] - kind[a] + x + 3) % 3;
}
}
cout << ans << endl;
}
/*
The WAM is F**KING interesting .
*/