并查集的代码实现
接口文件:UnionFind.java
/**
* 并查集
*
* @author whx
* @version 2018/9/1
*/
public interface UnionFind {
/**
* 获取并查集的大小
*
* @param
* @return int
* @author whx
* @version 2018/9/1
*/
int getSize();
/**
* 查询两个元素是否相连
*
* @param p
* @param q
* @return boolean
* @author whx
* @version 2018/9/1
*/
boolean isConnected(int p, int q);
/**
* 将两个元素并在一起
*
* @param p
* @param q
* @return void
* @author whx
* @version 2018/9/1
*/
void unionElements(int p, int q);
}
1、版本01
UnionFind1.java
/**
* 并查集
* 第一个版本,Quick Find,数组模拟实现
*
* @author whx
* @version 2018/9/1
*/
public class UnionFind1 implements UnionFind{
private int[] id;//存储当前位置元素的id
public UnionFind1(int size) {
id = new int[size];
for (int i = 0; i < size; i++) {
id[i] = i;
}
}
@Override
public int getSize() {
return id.length;
}
/**
* 查找元素p所对应的集合编号
*
* @param p
* @return int
* @author whx
* @version 2018/9/1
*/
private int find(int p){
if(p < 0 || p > id.length){
throw new IllegalArgumentException("p is out of bound.");
}
return id[p];
}
@Override
public boolean isConnected(int p, int q) {
return find(p) == find(q);
}
@Override
public void unionElements(int p, int q) {
int pId = find(p);
int qId = find(q);
if (pId == qId){
return;
}
for (int i = 0; i < id.length; i++) {
if(id[i] == pId){
id[i] = qId;
}
}
}
}
2、版本02
UnionFind2.java
/**
* 并查集
* 第二个版本,Quick Union,数组实现
*
* @author whx
* @version 2018/9/1
*/
public class UnionFind2 implements UnionFind{
private int[] parent;//存储当前位置元素的根节点
public UnionFind2(int size) {
parent = new int[size];
for (int i = 0; i < size; i++) {
parent[i] = i;
}
}
@Override
public int getSize() {
return parent.length;
}
/**
* 查找元素p所对应的根节点
*
* @param p
* @return int
* @author whx
* @version 2018/9/1
*/
private int find(int p){
if(p < 0 || p > parent.length){
throw new IllegalArgumentException("p is out of bound.");
}
while (p != parent[p]){
p = parent[p];
}
return p;
}
@Override
public boolean isConnected(int p, int q) {
return find(p) == find(q);
}
@Override
public void unionElements(int p, int q) {
int pRoot = find(p);
int qRoot = find(q);
if (pRoot == qRoot){
return;
}
parent[pRoot] = qRoot;
}
}
3、版本03
/**
* 并查集
* 第三个版本,基于size大小的优化,数组实现
*
* @author whx
* @version 2018/9/1
*/
public class UnionFind3 implements UnionFind{
private int[] parent;//存储当前位置元素的根节点
private int[] quantity;//表示以当前元素为根节点的集合中元素个数
public UnionFind3(int size) {
parent = new int[size];
quantity = new int[size];
for (int i = 0; i < size; i++) {
parent[i] = i;
quantity[i] = 1;
}
}
@Override
public int getSize() {
return parent.length;
}
/**
* 查找元素p所对应的根节点
*
* @param p
* @return int
* @author whx
* @version 2018/9/1
*/
private int find(int p){
if(p < 0 || p > parent.length){
throw new IllegalArgumentException("p is out of bound.");
}
while (p != parent[p]){
p = parent[p];
}
return p;
}
@Override
public boolean isConnected(int p, int q) {
return find(p) == find(q);
}
@Override
public void unionElements(int p, int q) {
int pRoot = find(p);
int qRoot = find(q);
if (pRoot == qRoot){
return;
}
//根据以当前元素为根节点的集合中元素个数的大小来合并
//将以当前元素为根节点的集合中元素个数小的合并到大的集合上面
if(quantity[pRoot] < quantity[qRoot]){
parent[pRoot] = qRoot;
quantity[qRoot] += quantity[pRoot];
}else {
parent[qRoot] = pRoot;
quantity[pRoot] += quantity[qRoot];
}
}
}
4、版本04
UnionFind4.java
/**
* 并查集
* 第四个版本,基于rank大小的优化,数组实现
*
* @author whx
* @version 2018/9/1
*/
public class UnionFind4 implements UnionFind{
private int[] parent;//存储当前位置元素的根节点
private int[] rank;//表示以当前元素为根节点的高度
public UnionFind4(int size) {
parent = new int[size];
rank = new int[size];
for (int i = 0; i < size; i++) {
parent[i] = i;
rank[i] = 1;
}
}
@Override
public int getSize() {
return parent.length;
}
/**
* 查找元素p所对应的根节点
*
* @param p
* @return int
* @author whx
* @version 2018/9/1
*/
private int find(int p){
if(p < 0 || p > parent.length){
throw new IllegalArgumentException("p is out of bound.");
}
while (p != parent[p]){
p = parent[p];
}
return p;
}
@Override
public boolean isConnected(int p, int q) {
return find(p) == find(q);
}
@Override
public void unionElements(int p, int q) {
int pRoot = find(p);
int qRoot = find(q);
if (pRoot == qRoot){
return;
}
//根据以当前元素为根节点的高度来合并
//将以当前元素为根节点的高度低的合并到高的集合上面
if(rank[pRoot] < rank[qRoot]){
parent[pRoot] = qRoot;
}else if(rank[pRoot] > rank[qRoot]) {
parent[qRoot] = pRoot;
}else {
parent[pRoot] = qRoot;
rank[qRoot] += 1;
}
}
}
5、版本05
UnionFind5.java
/**
* 并查集
* 第五个版本,路径压缩优化,数组实现
*
* @author whx
* @version 2018/9/1
*/
public class UnionFind5 implements UnionFind{
private int[] parent;//存储当前位置元素的根节点
private int[] rank;//表示以当前元素为根节点的高度
public UnionFind5(int size) {
parent = new int[size];
rank = new int[size];
for (int i = 0; i < size; i++) {
parent[i] = i;
rank[i] = 1;
}
}
@Override
public int getSize() {
return parent.length;
}
/**
* 查找元素p所对应的根节点
*
* @param p
* @return int
* @author whx
* @version 2018/9/1
*/
private int find(int p){
if(p < 0 || p > parent.length){
throw new IllegalArgumentException("p is out of bound.");
}
while (p != parent[p]){
//路径压缩
parent[p] = parent[parent[p]];
p = parent[p];
}
return p;
}
@Override
public boolean isConnected(int p, int q) {
return find(p) == find(q);
}
@Override
public void unionElements(int p, int q) {
int pRoot = find(p);
int qRoot = find(q);
if (pRoot == qRoot){
return;
}
//根据以当前元素为根节点的高度来合并
//将以当前元素为根节点的高度低的合并到高的集合上面
if(rank[pRoot] < rank[qRoot]){
parent[pRoot] = qRoot;
}else if(rank[pRoot] > rank[qRoot]) {
parent[qRoot] = pRoot;
}else {
parent[pRoot] = qRoot;
rank[qRoot] += 1;
}
}
}
6、版本06
UnionFind6.java
/**
* 并查集
* 第六个版本,路径压缩优化,数组实现
*
* @author whx
* @version 2018/9/1
*/
public class UnionFind6 implements UnionFind{
private int[] parent;//存储当前位置元素的根节点
private int[] rank;//表示以当前元素为根节点的高度
public UnionFind6(int size) {
parent = new int[size];
rank = new int[size];
for (int i = 0; i < size; i++) {
parent[i] = i;
rank[i] = 1;
}
}
@Override
public int getSize() {
return parent.length;
}
/**
* 查找元素p所对应的根节点
*
* @param p
* @return int
* @author whx
* @version 2018/9/1
*/
private int find(int p){
if(p < 0 || p > parent.length){
throw new IllegalArgumentException("p is out of bound.");
}
//递归路径压缩
if (p != parent[p]){
parent[p] = find(parent[p]);
}
return parent[p];
}
@Override
public boolean isConnected(int p, int q) {
return find(p) == find(q);
}
@Override
public void unionElements(int p, int q) {
int pRoot = find(p);
int qRoot = find(q);
if (pRoot == qRoot){
return;
}
//根据以当前元素为根节点的高度来合并
//将以当前元素为根节点的高度低的合并到高的集合上面
if(rank[pRoot] < rank[qRoot]){
parent[pRoot] = qRoot;
}else if(rank[pRoot] > rank[qRoot]) {
parent[qRoot] = pRoot;
}else {
parent[pRoot] = qRoot;
rank[qRoot] += 1;
}
}
}
测试代码
Main.java
import java.util.Random;
/**
* @author whx
* @version 2018/9/1
*/
public class Main {
public static void main(String[] args) {
int n = 1000000;
int size = 1000000;
UnionFind1 unionFind1 = new UnionFind1(size);
double testUnionFind1Time = testUnionFind(unionFind1, n);
System.out.println("UnionFind1: "+testUnionFind1Time+" s");
UnionFind2 unionFind2 = new UnionFind2(size);
double testUnionFind2Time = testUnionFind(unionFind2, n);
System.out.println("UnionFind2: "+testUnionFind2Time+" s");
UnionFind3 unionFind3 = new UnionFind3(size);
double testUnionFind3Time = testUnionFind(unionFind3, n);
System.out.println("UnionFind3: "+testUnionFind3Time+" s");
UnionFind4 unionFind4 = new UnionFind4(size);
double testUnionFind4Time = testUnionFind(unionFind4, n);
System.out.println("UnionFind4: "+testUnionFind4Time+" s");
UnionFind5 unionFind5 = new UnionFind5(size);
double testUnionFind5Time = testUnionFind(unionFind5, n);
System.out.println("UnionFind5: "+testUnionFind5Time+" s");
UnionFind6 unionFind6 = new UnionFind6(size);
double testUnionFind6Time = testUnionFind(unionFind6, n);
System.out.println("UnionFind6: "+testUnionFind6Time+" s");
}
/**
* 测试并查集性能
*
* @param unionFind
* @param n
* @return double
* @author whx
* @version 2018/9/1
*/
public static double testUnionFind(UnionFind unionFind, int n){
int size = unionFind.getSize();
Random random = new Random();
long startTime = System.nanoTime();
for (int i = 0; i < n; i++) {
int a = random.nextInt(size);
int b = random.nextInt(size);
unionFind.unionElements(a,b);
}
for (int i = 0; i < n; i++) {
int a = random.nextInt(size);
int b = random.nextInt(size);
unionFind.isConnected(a,b);
}
long endTime = System.nanoTime();
return (endTime - startTime) / 1e9;
}
}