class UnionFind:
def __init__(self, size):
# 每个节点的父节点
self.parent = list(range(size))
# 节点的秩
self.rank = [0] * size
# 查找父亲
def find(self, x):
# if x != self.parent[x]:
# # 路径压缩
# self.parent[x] = self.find(self.parent[x])
# return self.parent[x]
while x != self.parent[x]:
x = self.parent[x]
return x
# 合并
def union(self, x, y):
rootX = self.find(x)
rootY = self.find(y)
if rootX != rootY:
# 按秩合并
if self.rank[rootX] > self.rank[rootY]:
self.parent[rootY] = rootX
elif self.rank[rootX] < self.rank[rootY]:
self.parent[rootX] = rootY
else:
self.parent[rootY] = rootX
self.rank[rootX] += 1
# 使用示例
uf = UnionFind(10) # 假设有10个元素
uf.union(1, 2)
uf.union(2, 3)
uf.union(4, 5)
uf.union(6, 7)
uf.union(1, 4) # 此时1,2,3,4,5属于同一个集合, 6,7一个集合
print(uf.find(1) == uf.find(5)) # 输出True, 因为1和5在同一个集合内