一、基本思路
1、基本概念
belong[x] = a;
if(belong[x] == belong[y])
- 由此,引出并查集操作——用近乎O(1)的复杂度完成上述操作
- 特点:
-
-
- 用树的形式来文虎所有集合——不一定是二叉树,可能有很多子节点
-
-
-
- 对于每个节点,都存储下它的父节点 p[x] ,查询上一个节点的途径
-
-
- 若查某个节点属于哪个集合,则可以一直沿着 p[x] 向上查找,最终找到根节点
-
-
- 注意与Trie树的存储方式不同,每个节点代表的都是一个集合元素,而不是将每个元素进行更细致的拆分
2、原理及问题实现
-
基本原理:
-
- 每个集合都用一棵树来表示,树根的编号就是整个集合的编号(本身数值);每个节点p[x]存储着他的父节点,只有根节点满足p[x] = x。
-
问题1:如何判断是否是根节点
if(p[x] == x)
while(p[x] != x) x = p[x];
-
问题3:如何合并两个集合
- 解决办法:
-
- 直接将其中的一颗树的根节点,搬到另一棵树的某个位置(一般是搬到另一个数的根节点下面),然后将搬移的树的根节点的父节点修改为现在的新树根节点。
- 其中p[x]是原 x 集合的编号,p[y] 是原集合 y 的编号——> p[x] = y
3、问题优化
-
路径压缩算法
- 方法:
-
- 在进行根节点查找的时候,将路径上的每个节点的父节点都指向根节点
二、Java、C语言模板实现
static int findRoot(int x){
if(p[x] != x) {
p[x] = findRoot(p[x]);
}
return p[x];
}
p[findRoot(a)] = findRoot(b);
```c
(1)朴素并查集:
int p[N];
int find(int x)
{
if (p[x] != x) p[x] = find(p[x]);
return p[x];
}
for (int i = 1; i <= n; i ++ ) p[i] = i;
p[find(a)] = find(b);
(2)维护size的并查集:
int p[N], size[N];
int find(int x)
{
if (p[x] != x) p[x] = find(p[x]);
return p[x];
}
for (int i = 1; i <= n; i ++ )
{
p[i] = i;
size[i] = 1;
}
size[find(b)] += size[find(a)];
p[find(a)] = find(b);
(3)维护到祖宗节点距离的并查集:
int p[N], d[N];
int find(int x)
{
if (p[x] != x)
{
int u = find(p[x]);
d[x] += d[p[x]];
p[x] = u;
}
return p[x];
}
for (int i = 1; i <= n; i ++ )
{
p[i] = i;
d[i] = 0;
}
p[find(a)] = find(b);
d[find(a)] = distance;
作者:yxc
链接:https://www.acwing.com/blog/content/404/
来源:AcWing
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
三、例题题解
import java.io.*;
public class Main {
static int N = 100010;
static int[] parent = new int[N];
static int findParent(int x){
if(parent[x] != x) {
parent[x] = findParent(parent[x]);
}
return parent[x];
}
public static void main(String[] args) throws IOException {
StreamTokenizer input = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
PrintWriter output = new PrintWriter(new OutputStreamWriter(System.out));
input.nextToken();
int n = (int)input.nval;
input.nextToken();
int m = (int) input.nval;
for (int i = 1; i <= n; i++) {
parent[i] = i;
}
for (int i = 0; i < m; i++) {
input.nextToken();
String order = input.sval;
input.nextToken();
int a = (int)input.nval;
input.nextToken();
int b = (int) input.nval;
switch (order){
case "M":
parent[findParent(a)] = findParent(b);
break;
case "Q":
if (findParent(a) == findParent(b)){
output.println("Yes");
}else {
output.println("No");
}
}
}
output.flush();
output.close();
}
}
import java.io.*;
public class Main {
static int N = 100010;
static int[] p = new int[N];
static int[] size = new int[N];
static int findRoot(int x) {
if (x != p[x]) {
p[x] = findRoot(p[x]);
}
return p[x];
}
public static void main(String[] args) throws IOException {
StreamTokenizer input = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
PrintWriter output = new PrintWriter(new OutputStreamWriter(System.out));
input.nextToken();
int n = (int) input.nval;
input.nextToken();
int m = (int) input.nval;
for (int i = 1; i <= n; i++) {
p[i] = i;
size[i] = 1;
}
for (int i = 0; i < m; i++) {
input.nextToken();
String order = input.sval;
int a, b, rootA, rootB;
switch (order) {
case "C":
input.nextToken();
a = (int) input.nval;
input.nextToken();
b = (int) input.nval;
rootA = findRoot(a);
rootB = findRoot(b);
if (rootA != rootB) {
p[rootA] = rootB;
size[rootB] += size[rootA];
}
break;
case "Q1":
input.nextToken();
a = (int) input.nval;
input.nextToken();
b = (int) input.nval;
rootA = findRoot(a);
rootB = findRoot(b);
if (rootA == rootB) {
output.println("Yes");
}else {
output.println("No");
}
break;
case "Q2":
input.nextToken();
a = (int) input.nval;
rootA = findRoot(a);
output.println(size[rootA]);
break;
}
}
output.flush();
output.close();
}
}
import java.io.*;
public class Main {
static int N = 50010;
static int[] p = new int[N];
static int[] d = new int[N];
static int findRoot(int x){
if (p[x] != x){
int t = findRoot(p[x]);
d[x] += d[p[x]];
p[x] = t;
}
return p[x];
}
public static void main(String[] args) throws IOException {
StreamTokenizer input = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
PrintWriter output = new PrintWriter(new OutputStreamWriter(System.out));
input.nextToken();
int n = (int) input.nval;
input.nextToken();
int k = (int) input.nval;
int res = 0;
for (int i = 1; i <= n; i++) {
p[i] = i;
}
for (int i = 0; i < k; i++) {
input.nextToken();
int dd = (int) input.nval;
input.nextToken();
int x = (int) input.nval;
input.nextToken();
int y = (int) input.nval;
if (x > n || y > n){
res++;
}else{
int px = findRoot(x);
int py = findRoot(y);
if (dd == 1){
if (px == py && (d[x] - d[y])%3 != 0){
res++;
}
if (px != py){
p[px] = py;
d[px] = d[y] - d[x];
}
}
if (dd == 2){
if (px == py && (d[x] - d[y] - 1)%3 != 0){
res++;
}
if (px != py){
p[px] = py;
d[px] = d[y] - d[x] + 1;
}
}
}
}
output.println(res);
output.flush();
output.close();
}
}