前言
并查集的概念:将编号分别为 1~n 的 n 个对象划分为不相交集合,在每个集合中,选择其中某个元素代表所在集合。
并查集的基本应用是集合问题。在加上权值之后,利用并查集的合并优化和路径压缩,可以对权值所代表的具体应用进行高效的操作。
查询的优化(路径压缩)
普通的查询函数是查询元素 i 所属的集需要搜索路径到根节点并返回,这样显然会浪费很长时间。如果我们在返回时顺便把 i 所属的集改为根节点,那么下次再搜的时候可以直接 O(1) 得到结果。
int find(int x){
if(x == fa[x]) return fa[x];
return fa[x] = find(fa[x]);
}
带权并查集
除了并查集的基本应用 - 处理集合问题。
定义一个权值数组 d[] 把节点 i 到父节点的权值记为 d[i]
带权值的路径压缩
原来的权值 d[],经过压缩之后,更新为 d[]',例如 d[1]' = d[1] + [2] + d[3](也可以是乘、异或等)
int find(int x){
if(x != fa[x]){
int t = fa[x]; // 记录父节点
fa[x] = find(fa[x]); // 路径压缩,递归最后返回的是根节点
d[x] = (d[x] + d[t]) % 3;; // 权值更新为 x 到根节点的权值
}
return fa[x];
}
原来的 d[x] 是点 x 到它的父节点的权值,经过路径压缩后,x 直接指向根节点,d[x] 也更新为 x 到根节点的权值。这是通过递归实现的。
代码中先用 t 记录 x 的原父节点;在递归过程中,最后返回的是根节点;最后将当前节点的权值加上原父节点的权值(注意:经过递归,此时父节点也直接指向根节点,父节点的权值也已经更新为父节点直接到根节点的权值了),就得到当前节点到根节点的权值。
带权值的合并
在合并操作中,把点 x 与点 y 合并,就是把点 x 的根节点 fx 合并到点 y 的根节点 fy。在 fx 和 fy 之间增加权值,这个权值要符合题目的要求。
食物链题解
解题思路
本题有两种解法,一个是带权并查集,另一个是扩展并查集。
带权并查集
题目中的权值关系是指两个动物在食物链上的相对关系。
d(A->B) 表示 A 和 B 的关系,d(A->B) = 0 表示同类,d(A->B) = 1 表示 A 吃 B,d(A->B) = 2 表示 A 被 B 吃。
d(A->B) = 1,d(B->C) = 1,求 d(A->C)。因为 A 吃 B,B 吃 C,那么 C 应该吃 A,得 d(A->C) = 2
d(A->B) = 2,d(B->C) = 2,求 d(A->C)。因为 B 吃 A,C 吃 B,那么 A 应该吃 C,得 d(A->C) = 1
d(A->B) = 0,d(B->C) = 1,求 d(A->C)。因为 A B 同类,B 吃 C,那么 A 吃 C,得 d(A->C) = 1
d(A->C) = (d(A->B) + d(B->C)) % 3
/* 三倍并查集
#include<iostream>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<cstdio>
using namespace std;
const int maxn = 3000100;
int n, m;
int ans;
int fa[maxn];
int a[maxn];
int find(int x){
if(x == fa[x]) return fa[x];
return fa[x] = find(fa[x]);
}
int main(){
cin >> n >> m;
int z, x, y;
for(int i = 1; i <= n * 3; i++) fa[i] = i;
for(int i = 1; i <= m; i++){
cin >> z >> x >> y;
if(x > n || y > n){
ans++;
continue;
}
if(z == 1){
if(find(x) == find(y + n) || find(x + n) == find(y)){
ans++;
continue;
}
else{
fa[find(x)] = find(y);
fa[find(x + n)] = find(y + n);
fa[find(x + 2 * n)] = find(y + 2 * n);
}
}
else{
if(find(x) == find(y) || find(x) == find(y + n)){
ans++;
continue;
}
else{
fa[find(x)] = find(y + 2 * n);
fa[find(x + n)] = find(y);
fa[find(x + 2 * n)] = find(y + n);
}
}
}
cout << ans << '\n';
system("pause");
}
*/
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1001000;
int n, m;
int ans;
int fa[maxn];
int d[maxn];
int find(int x){
if(x != fa[x]){
int t = fa[x];
fa[x] = find(fa[x]);
d[x] = (d[x] + d[t]) % 3;;
}
return fa[x];
}
void merge(int x, int y, int type){
int fax = find(x);
int fay = find(y);
if(fax == fay){
if((type - 1) != (d[x] - d[y] + 3) % 3) ans++;
return;
}
else{
fa[fax] = fay;
d[fax] = (d[y] - d[x] + type - 1) % 3;
}
}
int main(){
cin >> n >> m;
int z, x, y;
for(int i = 1; i <= n; i++) fa[i] = i;
for(int i = 1; i <= m; i++){
cin >> z >> x >> y;
if(x > n || y > n || (z == 2 && x == y)){
ans++;
continue;
}
else merge(x, y, z);
}
cout << ans << '\n';
system("pause");
}