先下个定论:
- 种类并查集是一个逐步建立图的过程,图中不同节点的关系是逐步给出的,如果出现冲突,则停止或者跳过。
- 一般来说,双种类并查集开的空间是二倍的朴素并查集,三种类并查集要开的空间是三倍的朴素并查集
- 双种类和三种类并查集的find(查找父节点)操作都相同,都是:
int find(int x){ if(p[x] != x) p[x] = find(p[x]); // 路径压缩 return p[x]; }
关键点在于merge(合并操作)
-
具体说来,与朴素并查集不同的是,种类并查集merge的时候要去单独处理种类的关系。为了方便记忆,可以认为朴素并查集的merge是一句话,双种类并查集的merge是两句,三种类并查集的merge是三句,具体如下:
// 这里的敌人,朋友,x与y同类,x吃y没有做过题的同学可能不太清楚,可以去用题目理解一下。 // 朴素并查集 void merge(int x, int y){ p[find(x)] = p[find(y)]; } // 双种类并查集 void merge(int x, int y){ if(如果两人是朋友) { p[find(x)] = p[find(y)]; p[find(x+n)] = p[find(y+n)]; } else if(如果两人是敌人) { p[find(x+n)] = p[find(y)]; p[find(x)] = p[find(y+n)]; } } // 三种类并查集 void merge(int x, int y){ if(如果x与y是同类) { p[find(x)] = p[find(y)]; p[find(x+n)] = p[find(y+n)]; p[find(x+2n)] = p[find(y+2n)]; } else if(如果x吃y)// x+n表示x吃, x+2n表示x被吃 { p[find(x+n)] = p[find(y)]; p[find(x)] = p[find(y+2n)]; p[find(x+2n)] = p[find(y+n)]; } }
双种类:P1525 【关押罪犯】
AC代码:
#include <bits/stdc++.h>
using namespace std;
const int N = 20010, M = 100010;
struct node{
int a, b, c;
}record[M];
int n,m;
int pa[2*N];
bool cmp(struct node & a, struct node&b){
return a.c > b.c;
}
int find(int x){
if(pa[x] != x) pa[x] = find(pa[x]);
return pa[x];
}
void merge(int x, int y){
pa[find(x)] = find(y);
}
bool query(int x, int y){
return find(x) == find(y);
}
int main(){
scanf("%d%d", &n, &m);
for(int i=1; i<=2 * n; i++)
pa[i] = i;
for(int i=1; i<=m; i++){
int a1, a2, a3;
scanf("%d%d%d", &a1, &a2, &a3);
record[i].a = a1;
record[i].b = a2;
record[i].c = a3;
}
sort(record+1, record+m+1, cmp);
for(int i=1; i<=m; i++){
if(query(record[i].a, record[i].b)) {
cout << record[i].c;
return 0;
}
merge(record[i].a + n, record[i].b);
merge(record[i].a, record[i].b + n);
}
cout << 0;
}
三种类:P2024 [NOI2001] 食物链
#include <bits/stdc++.h>
using namespace std;
const int N = 5e4 + 10;
int p[3*N];
int n,k, ans;
int f, x, y;
int find(int x){
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}
void merge(int x, int y){
p[find(x)] = p[find(y)];
}
bool query(int x, int y){
return find(x) == find(y);
}
int main(){
scanf("%d%d", &n, &k);
for(int i=1; i<=3*n; i++){
p[i] = i;
}
for(int i=1; i<=k; i++){
scanf("%d%d%d", &f, &x, &y);
if(x > n || y > n) {
ans ++;
continue;
}
if(f == 1){
if(query(x, y+n) || query(x, y+2 * n)){
ans ++;
}else {
merge(x, y);
merge(x+n, y+n);
merge(x+2*n, y+2*n);
}
}else if(f == 2){
// x+n 表示x吃,x+2n表示x被吃
if(x == y){
ans ++;
continue;
}
if(query(x+2*n, y) || query(x,y)) {
ans ++;
}else {
merge(y, x+n);
merge(y+2*n, x);
merge(y+n,x+2*n);
}
}
}
cout << ans;
}