并查集
支持操作
在 O ( 1 ) O(1) O(1) 时间内
- 将两个集合合并
- 询问两个元素是否在一个集合当中
基本原理
每个集合用一棵树表示,树根的标号就是整个集合的标号。每个节点存储他的父节点 p [ x ] p[x] p[x]
- 如何判断树根
p [ x ] = x p[x]=x p[x]=x 来判断
- 如何求 x x x 的集合编号
w h i l e ( p [ x ] ≠ x ) x = p [ x ] while(p[x]\ne x)\ x=p[x] while(p[x]=x) x=p[x] → \to → 这一步时间复杂度较高
优化:在找根节点的途中,把路径上所有的点都指向根节点 O ( 1 ) O(1) O(1) 路径压缩
- 如何合并两个集合
加边,把左边集合的根连到右边集合,或者反过来
e . g . e.g. e.g. 设 p x p_x px 是 x x x 的集合编号, p y p_y py 是 y y y 的集合编号 p [ x ] = y p[x]=y p[x]=y 就合并了 x , y x,y x,y
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
#define debug(a) cout << #a << " " << a << endl
const int maxn = 1e5 + 7;
const int N = 1e6 + 7, M = N * 2;
const int inf = 0x3f3f3f3f;
const long long mod = 1e9 + 7;
int n, m;
int p[N];
int find(int x) { //返回x的祖宗节点 + 路径压缩
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}
int main() {
// freopen("input.txt", "r", stdin);
// freopen("output.txt", "w", stdout);
// ios::sync_with_stdio(false);
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++) p[i] = i; //并查集初始化
while(m--) {
char op[2];
int a, b;
scanf("%s%d%d", op, &a, &b);
if(op[0] == 'M') {
p[find(a)] = find(b);
} else {
if(find(a) == find(b)) puts("Yes");
else puts("No");
}
}
return 0;
}
维护额外信息
https://www.acwing.com/problem/content/242/
用每个点到根节点的距离表示它和根结点的关系,由于只有3种动物,所以知道任意两种动物之间的关系,就可以知道第三种关系。
设 x x x 表示一个节点离根节点的距离
x m o d 3 = 1 → x\mod3=1\to xmod3=1→ 可以吃根节点
x m o d 3 = 2 → x\mod3=2\to xmod3=2→ 可以被根节点吃
x m o d 3 = 0 → x\mod3=0\to xmod3=0→ 和根节点是同类
find
函数
int find(int x) {
if(p[x] != x) {
int u = find(p[x]);
d[x] += d[p[x]];
p[x] = u;
}
return p[x];
}
u = find(p[x])
先把父节点及以上压缩到根节点,这时父节点是根节点的一级子节点,x
是根节点的二级子节点。过程中d[p[x]]被更新为父节点到根节点的距离。
d[x] += d[p[x]]; p[x] = u;
先更新边权,再把x
也压到根节点。否则x的父节点到根节点的距离d[p[x]]
没加上就丢失了,所以要把 p[x]
到根节点的距离先存下来
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
#define debug(a) cout << #a << " " << a << endl
const int maxn = 1e5 + 7;
const int N = 1e6 + 7, M = N * 2;
const int inf = 0x3f3f3f3f;
const long long mod = 1e9 + 7;
int n, k;
int p[N], d[N];
int find(int x) {
if(p[x] != x) {
int u = find(p[x]);
d[x] += d[p[x]];
p[x] = u;
}
return p[x];
}
int main() {
scanf("%d%d", &n, &k);
for(int i = 1; i <= n; i++) p[i] = i;
int res = 0;
while(k--) {
int t, x, y;
scanf("%d%d%d", &t, &x, &y); //t表示类别 1是同类,2是互吃
if(x > n || y > n) res++;
else {
int px = find(x), py = find(y); //px,py 在哪个集合当中
if(t == 1) {
if(px == py && (d[x] - d[y]) % 3) { //如果x和y属同一类,他们到根节点的差mod3一定为0,如果不为零,他们一定不同类
res++;
} else if(px != py) { //说明px,py不在一个集合之中
p[px] = py; //让px和py归属同一类
//d[px]=? -> 人为计算决定
//如果它和dy一类 则 (d[x]+?-d[y])mod3==0
//所以 ?=d[y]-d[x]
d[px] = d[y] - d[x];
}
} else {
if(x==y) res++;
else if(px == py && (d[x] - d[y] - 1) % 3) res++;
else if(px != py) {
//x吃y->(d[x]-d[y]-1)%3 == 0
//d[px]=d[y]+1-d[x]
p[px] = py;
d[px] = d[y] + 1 - d[x];
}
}
}
}
printf("%d\n", res);
return 0;
}