并查集 (Disjoint Set Union, DSU)
主要用于维护、查询无向图的连通性,支持如下 3 个核心查询方法:
- 用 O ( 1 ) O(1) O(1) 的时间复杂度(摊销)判断两个元素是否处于同一连通分支;
- 用 O ( 1 ) O(1) O(1) 的时间复杂度(摊销)将两个元素所属连通分支合并;
- 用 O ( 1 ) O(1) O(1) 的时间复杂度查询当前连通分支数量。
并查集本质上是数组结构表示的多叉树,每个节点记录它的上级节点,每个节点所属的根节点即该节点所在的连通分支;
在每次查询节点所属根节点时,更新该节点到根节点路径上每个节点存储的上级节点为跟节点;
在每一次合并连通分支时,将元素数量少的分支合并到元素数量多的分支中。
__init__(self, n: int)
- 时间复杂度: O ( n ) O(n) O(n)
- 空间复杂度: O ( n ) O(n) O(n)
Parameters
n : int
:并查集中的元素数量
find(self, i: int) -> int
查询第 i i i 个元素所在的连通分支的根节点元素下标。
- 时间复杂度: O ( 1 ) O(1) O(1)
- 空间复杂度: O ( 1 ) O(1) O(1)
Parameters
i : int
:需要查询的元素下标
Returns
int
:第 i 个元素所在连通分支的根节点元素下标
union(self, *idx: int) -> bool
合并 idx 中所有元素的连通分支。
- 时间复杂度: O ( k ) O(k) O(k),其中 k k k 为元素数量
- 空间复杂度: O ( 1 ) O(1) O(1)
Parameters
*idx : int
:所有需要合并的元素
Returns
bool
:是否合并了两个不同的连通分支(只要有任意两个元素在合并前属于不同的连通分支即返回 True)
is_connected(self, i: int, j: int) -> bool
判断第 i 个元素和第 j 个元素是否在同一个连通分支,若在同一个连通分支则返回 True。
- 时间复杂度: O ( 1 ) O(1) O(1)
- 空间复杂度: O ( 1 ) O(1) O(1)
Parameters
i : int
:需要查询的第 1 个元素的下标j : int
:需要查询的第 2 个元素的下标
Returns
bool
:如第 i 个元素和第 j 个元素在同一个连通分支则返回 True,否则返回 False
get_size(self, i: int) -> int
返回第 i 个元素所在连通分支中元素的数量。
- 时间复杂度: O ( 1 ) O(1) O(1)
- 空间复杂度: O ( 1 ) O(1) O(1)
Parameters
i : int
:需要查询元素的下标
Returns
int
:第 i 个元素所在连通分支中元素的数量
array(self, refresh: bool = True) -> List[int]
(@property
)
获取每个元素存储的上级节点。
- 时间复杂度:若
refresh is True
,则时间复杂度为 O ( n ) O(n) O(n),若refresh is False
,则时间复杂度为 O ( 1 ) O(1) O(1) - 空间复杂度: O ( 1 ) O(1) O(1)
Parameters
refresh : bool, default = True
:是否需要将每个元素存储的上级节点刷新为所在连通分支的根节点,如开启则保证相同连通分支的元素在 array 中相同。
Returns
List[int]
:每个元素存储的上级节点
group_num(self) -> int
(@property
)
返回连通分支总数。
- 时间复杂度: O ( 1 ) O(1) O(1)
- 空间复杂度: O ( 1 ) O(1) O(1)
Returns
int
:连通分支总数
max_group_size(self) -> int
(@property
)
返回最大连通分支中包含的元素数量。
- 时间复杂度: O ( 1 ) O(1) O(1)
- 空间复杂度: O ( 1 ) O(1) O(1)
Returns
int
:最大连通分支中包含的元素数量。
from typing import List
class DSU:
"""并查集 (Disjoint Set Union, DSU)
主要用于维护、查询无向图的连通性,支持如下 2 个核心查询方法:
- 用 O(1) 的时间复杂度(摊销)判断两个元素是否处于同一连通分支;
- 用 O(1) 的时间复杂度(摊销)将两个元素所属连通分支合并;
- 用 O(1) 的时间复杂度查询当前连通分支数量。
并查集本质上是数组结构表示的多叉树,每个节点记录它的上级节点,每个节点所属的根节点即该节点所在的连通分支;
在每次查询节点所属根节点时,更新该节点到根节点路径上每个节点存储的上级节点为跟节点;
在每一次合并连通分支时,将元素数量少的分支合并到元素数量多的分支中。
Attributes
----------
_n : int
并查集中的元素数量
_array : List[int]
并查集中每个元素的上级节点(若没有上级节点则为其自身)
_size : List[int]
并查集中每个连通分支的元素数量(仅保证每个连通分支的根节点的元素数量是正确的)
_group_num : int
并查集中连通分支总数
_is_refresh : bool
是否已刷新所有元素所属连通分支
"""
def __init__(self, n: int) -> None:
"""初始化并查集
时间复杂度: O(n)
空间复杂度: O(n)
Parameters
----------
n : int
并查集中的元素数量
"""
self._n: int = n # 并查集中的元素数量
self._array: List[int] = [i for i in range(n)] # 并查集中每个元素的上级节点
self._size: List[int] = [1] * n # 并查集中每个连通分支的元素数量
self._group_num: int = n # 连通分支数量
self._is_refresh: bool = True # 是否已刷新所有元素所属连通分支
def find(self, i: int) -> int:
"""查询第 i 个元素所在的连通分支的根节点元素下标
时间复杂度: O(1)
空间复杂度: O(1)
Parameters
----------
i : int
需要查询的元素下标
Returns
-------
int
第 i 个元素所在连通分支的根节点元素下标
"""
if self._array[i] != i:
# 递归更新从当前节点到根节点路径上的每个元素的上级节点
self._array[i] = self.find(self._array[i])
return self._array[i]
def _union(self, i: int, j: int) -> bool:
"""合并第 i 个元素所在的连通分支和第 j 个元素所在的连通分支
时间复杂度: O(1)
空间复杂度: O(1)
Parameters
----------
i : int
需要合并的第 1 个元素的下标
j : int
需要合并的第 2 个元素的下标
Returns
-------
bool
是否合并了两个不同的连通分支
"""
i, j = self.find(i), self.find(j) # 获取两个元素所在连通分支的根节点
if i != j:
self._is_refresh = False # 合并两个连通分支,导致连通情况已发生变化,待更新
self._group_num -= 1 # 合并两个连通分支,导致连通分支总数减 1
if self._size[i] >= self._size[j]:
self._array[j] = i
self._size[i] += self._size[j]
else:
self._array[i] = j
self._size[j] += self._size[i]
return True
else:
return False
def union(self, *idx: int) -> bool:
"""合并 idx 中所有元素的连通分支
时间复杂度: O(k),其中 k 为元素数量
空间复杂度: O(1)
Parameters
----------
*idx : int
所有需要合并的元素
Returns
-------
bool
是否合并了两个不同的连通分支(只要有任意两个元素在合并前属于不同的连通分支即返回 True)
"""
if len(idx) <= 1:
return False # 如果只有小于等于 1 个元素则不合并
res = False
i = self.find(idx[0])
for j in idx[1:]:
j = self.find(j)
res = res or self._union(i, j) # 只要有任意两个元素属于不同分支,即将返回值置为 True
return res
def is_connected(self, i: int, j: int) -> bool:
"""判断第 i 个元素和第 j 个元素是否在同一个连通分支,若在同一个连通分支则返回 True
时间复杂度: O(1)
空间复杂度: O(1)
Parameters
----------
i : int
需要查询的第 1 个元素的下标
j : int
需要查询的第 2 个元素的下标
Returns
-------
bool
如第 i 个元素和第 j 个元素在同一个连通分支则返回 True,否则返回 False
"""
return self.find(i) == self.find(j)
def get_size(self, i: int) -> int:
"""返回第 i 个元素所在连通分支中元素的数量
时间复杂度: O(1)
空间复杂度: O(1)
Parameters
----------
i : int
需要查询元素的下标
Returns
-------
int
第 i 个元素所在连通分支中元素的数量
"""
return self._size[self.find(i)] # 因为 self._size 仅保证根节点正确,所以需首先查询到元素所属根节点
def _refresh(self) -> None:
"""刷新 self._array 中所有元素所属的连通分支
通过刷新,可以保证所有元素的在 self._array 中保存的均为该连通分支的根节点,从而可以保证相同连通分支的所有元素
在 self._array 中的连通分支编号相同。
这个方法当且仅当在需要获取 self._array 前需要执行。
时间复杂度: O(n)
空间复杂度: O(1)
"""
if self._is_refresh is False: # 如连通状态发生变化,且未刷新所有元素所属连通分支
for i in range(self._n):
self.find(i) # 逐个元素更新它到连通分支根节点的路径
self._is_refresh = True
@property
def array(self, refresh: bool = True) -> List[int]:
"""获取每个元素存储的上级节点
时间复杂度: 若 refresh is True,则时间复杂度为 O(n),若 refresh is False,则时间复杂度为 O(1)
空间复杂度: O(1)
Parameters
----------
refresh : bool, default = True
是否需要将每个元素存储的上级节点刷新为所在连通分支的根节点,如开启则保证相同连通分支的元素在 array 中相同
Returns
-------
List[int]
每个元素存储的上级节点
"""
if refresh is True:
self._refresh()
return self._array
@property
def group_num(self) -> int:
"""返回连通分支总数
时间复杂度: O(1)
空间复杂度: O(1)
Returns
-------
int
连通分支总数
"""
return self._group_num
@property
def max_group_size(self) -> int:
"""返回最大连通分支中包含的元素数量
时间复杂度: O(n)
空间复杂度: O(1)
Returns
-------
int
最大连通分支中包含的元素数量
"""
return max(self._size)