在二叉树中找到两个节点的最近公共祖先
题目
给定一棵二叉树的头节点root,以及这棵树中的两个节点node1, node2,返回node1, node2的最近公共祖先节点。
TreeNode
"""
Data Structures And Algorithms
find the lowest common ancestor of two nodes in a binary tree
"""
import random
from collections import deque
class TreeNode():  # pylint: disable=too-few-public-methods
    """A single node of a binary tree.

    Stores a value together with links to the left child, the right child
    and the parent. All three links start out as None.
    """

    def __init__(self, val):
        self.val = val
        # A fresh node is not linked to anything yet.
        self.left = self.right = self.parent = None
BST
class BST():
    """Binary search tree that stores each value at most once.
    """

    def __init__(self):
        self.root = None

    def insert(self, val):
        """Add val to the tree; a value that is already present is ignored.
        """
        if self.root is None:
            self.root = TreeNode(val)
            return
        cur = self.root
        while True:
            if val < cur.val:
                # Smaller values belong in the left subtree.
                if cur.left is None:
                    cur.left = TreeNode(val)
                    return
                cur = cur.left
            elif val > cur.val:
                # Larger values belong in the right subtree.
                if cur.right is None:
                    cur.right = TreeNode(val)
                    return
                cur = cur.right
            else:
                # Equal value: nothing to insert.
                return

    def search(self, val):
        """Return the node holding val, or None when val is not in the tree.
        """
        cur = self.root
        while cur is not None:
            if val < cur.val:
                cur = cur.left
            elif val > cur.val:
                cur = cur.right
            else:
                return cur
        return None
思路1
后序遍历,假设处理cur节点的左子树返回left,右子树返回right,现在处理cur,
- cur 是None,或者cur是node1 或node2,返回cur
- left, right都是None,说明子树上未发现node1, node2,返回None
- left, right都不是None,说明分别在左右子树上找到了node1, node2,cur就是最近的公共祖先
- left, right一个为None,一个不为None,假设不为None的那个节点是node,则node子树可能:
1)发现了node1或node2
2)发现了node1,node2的公共祖先
返回node
def lowest_ancestor1(root, node1, node2):
    """Find the lowest common ancestor with a post-order traversal.

    Each call reports what it found in the subtree rooted at root:
    None when neither target is there, otherwise one of the targets or
    their common ancestor.
    """
    if root is None:
        return None
    if root == node1 or root == node2:
        # Reaching a target ends the search of this subtree.
        return root
    found_left = lowest_ancestor1(root.left, node1, node2)
    found_right = lowest_ancestor1(root.right, node1, node2)
    if found_left is None:
        return found_right
    if found_right is None:
        return found_left
    # One target on each side: root is where the two paths meet.
    return root
思路2
实现一个判断一个节点是不是另一个节点祖先的函数,
- 如果node1是node2的祖先,返回node1
- 如果node2是node1的祖先,返回node2
- 否则node设为root,执行
- 如果node的左子树是node1, node2的公共祖先,转到左子树
- 如果node的右子树是node1, node2的公共祖先,转到右子树
- 直到左右子树都不是node1, node2的公共祖先,则node为所找的节点
def lowest_ancestor2(root, node1, node2):
    """Find the lowest common ancestor by repeated ancestor checks.

    Starting at root, keep descending into whichever child subtree still
    contains both targets; the node where that is no longer possible is
    the answer.
    """
    def is_ancestor_of(ancestor, node):
        """Return True when node lies in the subtree rooted at ancestor.

        A node counts as its own ancestor.
        """
        pending = [ancestor]
        while pending:
            cur = pending.pop()
            if cur is None:
                continue
            if cur == node:
                return True
            pending.append(cur.right)
            pending.append(cur.left)
        return False

    if root in (None, node1, node2):
        return root
    # If one target sits above the other, that target is the answer.
    for upper, lower in ((node1, node2), (node2, node1)):
        if is_ancestor_of(upper, lower):
            return upper
    cur = root
    while cur:
        for child in (cur.left, cur.right):
            if is_ancestor_of(child, node1) and is_ancestor_of(child, node2):
                cur = child
                break
        else:
            # Neither child subtree holds both targets.
            return cur
    return None
思路3
找到root->node1的路径path1,root->node2的路径path2,找到path1,path2中最后一个相同的node,就是最近公共祖先
def lowest_ancestor3(root, node1, node2):
    """Find the lowest common ancestor by comparing root-to-node paths.

    Builds the path from root to each target and returns the last node the
    two paths share. Returns None when the tree is empty or when either
    target cannot be reached from root.
    """
    def get_path(cur, node, path):
        """Fill path with the nodes from cur down to node.

        Returns True when node was found below cur; on failure path is
        left as it was on entry.
        """
        if cur is None:
            return False
        # Tentatively put the current node on the path.
        path.append(cur)
        if cur == node:
            return True
        if get_path(cur.left, node, path):
            return True
        if get_path(cur.right, node, path):
            return True
        # Dead end: node is not below cur, so take cur off again.
        path.pop()
        return False

    path1 = []
    path2 = []
    # Without both paths there is no common ancestor. Indexing with
    # path1[i - 1] at i == 0 would wrap around to the last element.
    if not (get_path(root, node1, path1) and get_path(root, node2, path2)):
        return None
    ancestor = None
    for first, second in zip(path1, path2):
        if first != second:
            break
        ancestor = first
    return ancestor
思路4
分层遍历过程中记录各个节点的父节点,然后比较路径
def lowest_ancestor4(root, node1, node2):
    """Find the lowest common ancestor using a parent table.

    Walks the tree level by level, recording each node's parent, until
    both targets have been seen. It then collects every ancestor of node1
    and climbs from node2 until it reaches one of them. Nodes must be
    hashable. Returns None when the tree is empty or when either target is
    not in the tree.
    """
    if root is None:
        return None
    parents = {root: None}
    queue = deque([root])
    # Stop as soon as both targets have a recorded parent, or when the
    # whole tree has been visited without finding them.
    while queue and (node1 not in parents or node2 not in parents):
        node = queue.popleft()
        for child in (node.left, node.right):
            if child is not None:
                parents[child] = node
                queue.append(child)
    if node1 not in parents or node2 not in parents:
        return None
    # node1 itself plus everything above it, up to the root.
    ancestors = set()
    while node1 is not None:
        ancestors.add(node1)
        node1 = parents[node1]
    # The root is in ancestors, so this climb always terminates.
    while node2 not in ancestors:
        node2 = parents[node2]
    return node2
进阶题目
如果查询操作非常频繁,怎样处理?
思路1
上面思路4中,将所有节点的父节点记录下来,查询的时候只需要找到路径,然后比较路径。
class LowestAncestorRecord1():
    """Answer lowest-common-ancestor queries from a precomputed parent table.

    self.map sends every node of the tree to its parent; the root is sent
    to None.
    """

    def __init__(self, root):
        self.map = {} if root is None else {root: None}
        self.set_map(root)

    def set_map(self, node):
        """Record the parent of every node below node.
        """
        stack = [node]
        while stack:
            cur = stack.pop()
            if cur is None:
                continue
            for child in (cur.left, cur.right):
                if child is not None:
                    self.map[child] = cur
            # Push right first so the left subtree is handled first.
            stack.append(cur.right)
            stack.append(cur.left)

    def query(self, node1, node2):
        """Return the lowest common ancestor of node1 and node2.
        """
        # Collect node1 and all of its ancestors.
        seen = set()
        cur = node1
        while cur in self.map:
            seen.add(cur)
            cur = self.map[cur]
        # Climb from node2 until the two chains meet.
        cur = node2
        while cur not in seen:
            cur = self.map[cur]
        return cur
思路2
开始的时候就找到任意两个节点node1, node2的最近公共祖先,查询的时候查字典。
对于二叉树中任意一棵子树,假设子树根节点root,则
- root的所有后代节点和root的最近公共祖先都是root
- root的左子树中每个节点和root的右子树中每个节点的最近公共祖先都是root
class LowestAncestorRecord2():
    """Answer lowest-common-ancestor queries from a precomputed pair table.

    self.map holds one entry per pair of distinct nodes in the tree:
    (descendant, ancestor) -> ancestor, and
    (node in left subtree of x, node in right subtree of x) -> x.
    Nodes must be hashable.
    """

    def __init__(self, root):
        self.map = {}
        self.set_map(root)

    def set_map(self, node):
        """Fill self.map for every pair inside the subtree rooted at node.
        """
        if node is None:
            return
        self.process_node(node.left, node)
        self.process_node(node.right, node)
        # Pair up only this node's own left and right subtrees. The
        # recursive calls below do the same for every descendant, so
        # calling process_left_right here would redo that work once per
        # ancestor.
        self.process_left(node.left, node.right, node)
        self.set_map(node.left)
        self.set_map(node.right)

    def process_node(self, node, ancestor):
        """Record ancestor as the answer for (n, ancestor), for node and
        every node n below it.
        """
        if node is None:
            return
        self.map[(node, ancestor)] = ancestor
        self.process_node(node.left, ancestor)
        self.process_node(node.right, ancestor)

    def process_left_right(self, node):
        """For node and each of its descendants x, record x as the answer
        for every (node in x's left subtree, node in x's right subtree).

        set_map no longer calls this; it is kept so existing callers of
        the method keep working.
        """
        if node is None:
            return
        self.process_left(node.left, node.right, node)
        self.process_left_right(node.left)
        self.process_left_right(node.right)

    def process_left(self, left, right, ancestor):
        """Pair left and every node below it with the whole right subtree.
        """
        if left is None:
            return
        self.process_right(left, right, ancestor)
        self.process_left(left.left, right, ancestor)
        self.process_left(left.right, right, ancestor)

    def process_right(self, left, right, ancestor):
        """Record ancestor as the answer for (left, r), for right and
        every node r below it.
        """
        if right is None:
            return
        self.map[(left, right)] = ancestor
        self.process_right(left, right.left, ancestor)
        self.process_right(left, right.right, ancestor)

    def query(self, node1, node2):
        """Return the lowest common ancestor of node1 and node2.

        A pair is stored in one order only, so both orders are tried.
        Raises KeyError when the pair is not in the table.
        """
        if node1 == node2:
            return node1
        if (node1, node2) in self.map:
            return self.map[(node1, node2)]
        return self.map[(node2, node1)]
思路3
查询过程中记录下查过的结果,每次查找先查字典,字典没有再真正搜索
class LowestAncestorRecord3():  # pylint: disable=too-few-public-methods
    """Answer lowest-common-ancestor queries, remembering earlier answers.

    self.map caches results keyed by the (node1, node2) pair of the query
    that first computed them.
    """

    def __init__(self, root):
        self.root = root
        self.map = {}

    def query(self, node1, node2):
        """Return the lowest common ancestor of node1 and node2.

        The cache is consulted for the pair in either order; only on a
        miss is the tree actually searched.
        """
        for key in ((node1, node2), (node2, node1)):
            if key in self.map:
                return self.map[key]
        answer = lowest_ancestor4(self.root, node1, node2)
        self.map[(node1, node2)] = answer
        return answer
测试
def test_lowest_ancestor(count, test_count):
    """Cross-check every lowest-ancestor implementation on a random BST.

    Builds a BST from the values 0..count-1 inserted in random order, then
    runs test_count random queries and requires that all implementations
    return a node with the same value. On a mismatch the index of the
    first disagreeing implementation is printed and an Exception is raised.
    """
    bst = BST()
    vals = list(range(count))
    random.shuffle(vals)
    for val in vals:
        bst.insert(val)
    # Look nodes up by value so that queries can pick them at random.
    nodes = {val: bst.search(val) for val in range(count)}
    record1 = LowestAncestorRecord1(bst.root)
    record2 = LowestAncestorRecord2(bst.root)
    record3 = LowestAncestorRecord3(bst.root)
    for __ in range(test_count):
        node1 = nodes[random.randint(0, count - 1)]
        node2 = nodes[random.randint(0, count - 1)]
        results = [
            lowest_ancestor1(bst.root, node1, node2),
            lowest_ancestor2(bst.root, node1, node2),
            lowest_ancestor3(bst.root, node1, node2),
            lowest_ancestor4(bst.root, node1, node2),
            record1.query(node1, node2),
            record2.query(node1, node2),
            record3.query(node1, node2),
        ]
        # Compare each answer with the one before it.
        pairs = zip(results, results[1:])
        for idx, (previous, current) in enumerate(pairs, start=1):
            if previous.val != current.val:
                print(idx)
                raise Exception('Error')
if __name__ == '__main__':
    # (tree size, number of random queries) for each self-check run.
    for tree_size, query_count in ((10, 1000), (30, 100), (300, 100)):
        test_lowest_ancestor(tree_size, query_count)