# The tree has n nodes in total; its n-1 edges are stored in `edges`.
import math
class LCA:
    """Lowest-common-ancestor queries on a rooted tree using binary lifting.

    The tree has ``len(edges) + 1`` nodes, which must be labelled
    ``0 .. n-1`` because they are used directly as list indices. The tree is
    rooted at node 0.

    Attributes:
        log: Number of lifting levels; the smallest ``k`` with ``2**k >= n``.
        parent: ``parent[v][i]`` is the ``2**i``-th ancestor of ``v``, or -1
            when ``v`` has no such ancestor.
        depth: ``depth[v]`` is the number of edges between node 0 and ``v``.
        graph: Undirected adjacency lists built from ``edges``.
    """

    def __init__(self, edges):
        """Build the adjacency lists and the ancestor table.

        Args:
            edges: Sequence of ``(u, v)`` pairs describing the tree's edges.
        """
        n = len(edges) + 1
        # Same value as ceil(log2(n)) for n >= 1, but computed with exact
        # integer arithmetic instead of floating point.
        self.log = (n - 1).bit_length()
        self.parent = [[-1] * self.log for _ in range(n)]
        self.depth = [0] * n
        self.graph = [[] for _ in range(n)]
        for u, v in edges:
            self.graph[u].append(v)
            self.graph[v].append(u)
        self.dfs(0, -1)

    def dfs(self, node, par):
        """Fill ``depth`` and ``parent`` for every node below ``node``.

        The traversal uses an explicit stack rather than recursion, so very
        deep trees (for example a long path) do not hit Python's recursion
        limit.

        Args:
            node: Node whose subtree is processed.
            par: The neighbour of ``node`` to treat as its parent (-1 for the
                root); it is skipped when walking ``node``'s adjacency list.
        """
        stack = [(node, par)]
        while stack:
            current, current_par = stack.pop()
            for child in self.graph[current]:
                if child == current_par:
                    continue
                self.depth[child] = self.depth[current] + 1
                row = self.parent[child]
                row[0] = current
                # Every ancestor of ``child`` was discovered earlier, so its
                # table row is already complete and can be reused here.
                for i in range(1, self.log):
                    if row[i - 1] != -1:
                        row[i] = self.parent[row[i - 1]][i - 1]
                stack.append((child, current))

    def query(self, u, v):
        """Return the lowest common ancestor of nodes ``u`` and ``v``."""
        # Make ``u`` the deeper (or equally deep) node.
        if self.depth[u] < self.depth[v]:
            u, v = v, u
        # Lift ``u`` to the depth of ``v``, one set bit of the gap at a time.
        diff = self.depth[u] - self.depth[v]
        for i in range(self.log):
            if diff & (1 << i):
                u = self.parent[u][i]
        if u == v:
            return u
        # Take the largest jumps that keep the two nodes distinct; afterwards
        # both are direct children of the answer.
        for i in reversed(range(self.log)):
            if self.parent[u][i] != self.parent[v][i]:
                u = self.parent[u][i]
                v = self.parent[v][i]
        return self.parent[u][0]
# Re-run the test code: build the table for a small sample tree and look up
# the lowest common ancestor of a few node pairs.
edges = [[0, 1], [1, 2], [1, 3], [3, 4]]
test_pairs = [(2, 4), (3, 4), (2, 3)]
lca = LCA(edges)
# Each entry is (first node, second node, their lowest common ancestor).
results = [(a, b, lca.query(a, b)) for a, b in test_pairs]
print(results)