由于O(N^2)
的存储会超内存限制,所以直接记忆化肯定凉了。
只能考虑O(NlogN)
的存储,方法就对每个node都记下,它k=1,2,4,8,…时候的k祖先。比如,可以用f[h][node]
表示node
的2**h
祖先是谁。
需要注意的是,记这个的时候也得做一些优化,直接暴力遍历还是会超时。原理就是:
node的
2
h
2^h
2h距离的祖先就是它
2
h
−
1
2^{h-1}
2h−1祖先的
2
h
−
1
2^{h-1}
2h−1祖先,f[h][node] = f[h-1][f[h-1][node]]
。
然后在query的时候,直接跳到最接近的那个位置h,使得
2
h
<
=
k
<
2
h
+
1
2^h<=k<2^{h+1}
2h<=k<2h+1,然后找f[h][node]
的
k
−
2
h
k-2^h
k−2h祖先即可。
时间复杂度:
- constructor: O ( N l o g N ) O(NlogN) O(NlogN)
- query: O ( Q l o g N ) O(QlogN) O(QlogN)
class TreeAncestor:
def __init__(self, n: int, parent: List[int]):
h = floor(log(n, 2))
self.dist = [parent] + [[-1]*n for _ in range(h)] # dist[i][j] = getKthAncestor(j, 2**i)
for i in range(1, h+1):
for j in range(n):
m = self.dist[i-1][j]
if m != -1:
self.dist[i][j] = self.dist[i-1][m]
# print(self.dist)
@lru_cache()
def getKthAncestor(self, node: int, k: int) -> int:
if node == -1:
return -1
h = floor(log(k, 2))
if k > 2**h:
return self.getKthAncestor(self.dist[h][node], k - 2**h)
return self.dist[h][node]
# Your TreeAncestor object will be instantiated and called as such:
# obj = TreeAncestor(n, parent)
# param_1 = obj.getKthAncestor(node,k)