LCA(Lowest Common Ancestors,最近公共祖先)
在有根树中,找出某两个结点u和v最近的公共祖先。一般有两种方法:树上倍增大跳求LCA和Tarjan离线算法求LCA,本篇主要描述Tarjan离线算法求LCA。
Tarjan算法求LCA基本思路
1、任选一个节点root作为根节点,从根节点开始遍历。
2、遍历当前节点u的所有子节点(邻居节点)v,并置u节点为已访问。
3、若v还有子节点,则继续递归其子节点。
4、并查集合并v到u上面。
5、寻找与当前节点u有关的查询,判断所查询的除了u以外的另一节点u'是否已访问。
6、若u'已经被访问过,则u和u'的LCA就是u'已经被合并到的节点。
Tarjan求LCA代码模版(python、c++)
def find(x:int)->int:#并查集模版
if x !=fa[x]:
fa[x] = find(fa[x])
return fa[x]
def tarjan(i: int, f: int):
vis[i] = 1#标记为已访问
for nxt in nei[i].keys():#搜索所有的子节点
if nxt == f or vis[nxt]: continue#如果已访问或者为父节点,跳过
tarjan(nxt, i)#dfs其子节点
fa[nxt] = i#dfs遍历完子节点后合并节点,即缩步方法
for j, v in query[i]:#离线查找所有与节点i有关的访问
if vis[j] == 1:#判断另一节点是否被访问了,即合并了
lca[v] = find(j)#则其LCA为j当前合并的点
tarjan(0, -1)
unordered_map<TreeNode*, TreeNode*> fa;
unordered_set<TreeNode*> vis;
function<TreeNode*(TreeNode*)> find_fa = [&](TreeNode* node){
if(fa[node] != node) fa[node] = find_fa(fa[node]);
return fa[node];
};
TreeNode* ans = nullptr;
function<void(TreeNode*)> tarjan = [&](TreeNode* node){
vis.insert(node);
fa[node] = node;
if(node->left){
tarjan(node->left);
fa[node->left] = node;
}
if(node->right){
tarjan(node->right);
fa[node->right] = node;
}
if(node == p && vis.count(q)){
ans = find_fa(q);
}
else if(node == q && vis.count(p)){
ans = find_fa(p);
}
};
tarjan(root);
return ans;
相关题目练习
1644. 二叉树的最近公共祖先 II
题目描述
数据集范围
思路
LCA模版题,套模版就行。
代码实现
def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
fa = {}
vis = set()
query = {p:q,q:p}
ans = None
def find(x:'TreeNode')->'TreeNode':
if fa[x]!=x:
fa[x] = find(fa[x])
return fa[x]
def tarjan(node:'TreeNode',f:'TreeNode'):
vis.add(node)
fa[node] = node
for nxt in node.left,node.right:
if not nxt or nxt == f:continue
tarjan(nxt,node)
fa[nxt] = node
if node in query and query[node] in vis:
nonlocal ans
ans = find(query[node])
tarjan(root,None)
return ans
2846. 边权重均等查询
题目描述
数据集范围
思路
显然对于每个查询的结果即为,两个节点之间的路径长度减去路径中最多的权重值的数目。注意到1<=wi<=26,即树的权值仅有26种状态,那么我们可以用O(26*n)的空间存储每个节点到根节点的每个权值的个数,即存入w[n][26]数组中,然后记录每个节点到根节点的深度h[i],并且找到每次所查询的两个节点的公共祖先,并存入LCA数组当中,然后每个查询(u,v)的结果即为:
代码实现
def minOperationsQueries(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
#求到根节点权重和深度
nei = [{} for _ in range(n)]
for a,b,c in edges:
nei[a][b] = c
nei[b][a] = c
h = [0]*n
w = [[0]*27 for _ in range(n)]
def dfs(i:int,fa:int,deep:int):
for j in nei[i]:
if j==fa: continue
h[j] = deep
weight = nei[i][j]
for k in range(1,27):
w[j][k] = w[i][k]
if k == weight:
w[j][k]+= 1
dfs(j, i, deep+1)
dfs(0,-1,1)
#tarjan
lca = [-1]*len(queries)
vis = [0]*n
fa = [i for i in range(n)]
def find(x):
if x!=fa[x]:
fa[x] = find(fa[x])
return fa[x]
query = [[] for _ in range(n)]
for i,(q1,q2) in enumerate(queries):
query[q1].append((q2,i))
query[q2].append((q1,i))
def tarjan(i:int,f:int):
vis[i] = 1
for nxt in nei[i].keys():
if nxt == f or vis[nxt]:continue
tarjan(nxt, i)
fa[nxt] = i
for j,v in query[i]:
if vis[j] == 1:
lca[v] = find(j)
tarjan(0, -1)
ans = []
for i,(q1,q2) in enumerate(queries):
ancestor = lca[i]
n = h[q1]+h[q2] - 2*h[ancestor]
delete = 0
for k in range(1,27):
delete = max(delete,w[q1][k]+w[q2][k]-2*w[ancestor][k])
ans.append(n-delete)
return ans