/**
 * Disjoint-set (union-find) structure over an arbitrary collection of values,
 * using union by size and path compression.
 *
 * <p>Every distinct value passed to the constructor starts in its own singleton set.
 * Values are looked up through {@code HashMap} (i.e. {@code equals}/{@code hashCode});
 * each distinct value is wrapped in an internal node, and set structure is tracked
 * on those nodes by identity. Instances are not synchronized.
 *
 * @param <V> element type
 */
public class UnionFind<V> {
    /** Identity wrapper for one registered element. */
    private static final class Node<T> {
        final T value;

        Node(T v) {
            value = v;
        }
    }

    /** Maps each registered value to its node. */
    private final HashMap<V, Node<V>> nodes;
    /** Maps each node to its parent; a node mapped to itself is its set's representative. */
    private final HashMap<Node<V>, Node<V>> parents;
    /** Has an entry only for representatives: the number of elements in that set. */
    private final HashMap<Node<V>, Integer> sizeMap;

    /**
     * Creates a structure in which each distinct element of {@code values} is in its own set.
     * Duplicate values are registered only once.
     *
     * @param values the elements to register
     */
    public UnionFind(List<V> values) {
        nodes = new HashMap<>();
        parents = new HashMap<>();
        sizeMap = new HashMap<>();
        for (V cur : values) {
            if (nodes.containsKey(cur)) {
                // Re-registering would orphan the earlier node in parents/sizeMap and inflate size().
                continue;
            }
            Node<V> node = new Node<>(cur);
            nodes.put(cur, node);
            parents.put(node, node);
            sizeMap.put(node, 1);
        }
    }

    /**
     * Returns the representative of the set containing {@code cur}, and re-points every
     * node on the traversed path directly at that representative (path compression).
     *
     * @param cur a node owned by this structure
     * @return the representative node; {@code null} if {@code cur} is {@code null}
     */
    public Node<V> findFather(Node<V> cur) {
        Node<V> root = cur;
        while (root != parents.get(root)) {
            root = parents.get(root);
        }
        // Second pass: hang every node on the walked path directly off the root.
        while (cur != root) {
            Node<V> next = parents.get(cur);
            parents.put(cur, root);
            cur = next;
        }
        return root;
    }

    /**
     * Reports whether {@code a} and {@code b} are in the same set.
     *
     * @return {@code false} if either value was never registered, otherwise whether both
     *     values share a representative
     */
    public boolean isSameSet(V a, V b) {
        if (!nodes.containsKey(a) || !nodes.containsKey(b)) {
            // Without this guard two unknown values would both resolve to null and compare equal.
            return false;
        }
        return findFather(nodes.get(a)) == findFather(nodes.get(b));
    }

    /**
     * Merges the sets containing {@code a} and {@code b}. The smaller set is attached under
     * the larger one; on a tie, {@code a}'s representative is kept.
     * Does nothing if either value was never registered or both are already in one set.
     */
    public void union(V a, V b) {
        if (!nodes.containsKey(a) || !nodes.containsKey(b)) {
            // An unknown value would otherwise unbox sizeMap.get(null) and throw NPE.
            return;
        }
        Node<V> aHead = findFather(nodes.get(a));
        Node<V> bHead = findFather(nodes.get(b));
        if (aHead != bHead) {
            int aSetSize = sizeMap.get(aHead);
            int bSetSize = sizeMap.get(bHead);
            Node<V> big = aSetSize >= bSetSize ? aHead : bHead;
            Node<V> small = big == aHead ? bHead : aHead;
            parents.put(small, big);
            sizeMap.put(big, aSetSize + bSetSize);
            // Only representatives keep a size entry, so sizeMap.size() equals the set count.
            sizeMap.remove(small);
        }
    }

    /** Returns the number of disjoint sets currently held. */
    public int size() {
        return sizeMap.size();
    }
}
模板 (template): array-based UnionFind for integer elements 0..N-1
/**
 * Array-backed disjoint-set structure over the integers {@code 0 .. N-1},
 * using union by size and path compression.
 */
public static class UnionFind {
    /** parent[x] is x's parent; x is a representative when parent[x] == x. */
    int[] parent;
    /** size[r] is the element count of the set whose representative is r. */
    int[] size;
    /** Current number of disjoint sets. */
    int sets;
    /** Scratch buffer reused by find() to record the path walked to the root. */
    int[] stack;

    /**
     * Creates N singleton sets, one per integer in {@code 0 .. N-1}.
     *
     * @param N number of elements
     */
    public UnionFind(int N) {
        parent = new int[N];
        size = new int[N];
        stack = new int[N];
        sets = N;
        int k = 0;
        while (k < N) {
            parent[k] = k;
            size[k] = 1;
            k++;
        }
    }

    /**
     * Returns the representative of {@code i}'s set and points every node on the
     * walked path directly at it.
     */
    private int find(int i) {
        int depth = 0;
        int root = i;
        while (parent[root] != root) {
            stack[depth++] = root;
            root = parent[root];
        }
        // Unwind the recorded path, most recently pushed node first.
        while (depth > 0) {
            parent[stack[--depth]] = root;
        }
        return root;
    }

    /**
     * Merges the sets containing {@code i} and {@code j}. The smaller set is attached to the
     * larger; on a tie, {@code j}'s representative is kept. No-op if they are already joined.
     */
    public void union(int i, int j) {
        int rootI = find(i);
        int rootJ = find(j);
        if (rootI == rootJ) {
            return;
        }
        int larger;
        int smaller;
        if (size[rootI] > size[rootJ]) {
            larger = rootI;
            smaller = rootJ;
        } else {
            larger = rootJ;
            smaller = rootI;
        }
        size[larger] += size[smaller];
        parent[smaller] = larger;
        sets--;
    }

    /** Reports whether {@code i} and {@code j} currently share a representative. */
    public boolean isSameSet(int i, int j) {
        int rootI = find(i);
        int rootJ = find(j);
        return rootI == rootJ;
    }
}