在二叉树中找到两个节点的最近公共祖先

在二叉树中找到两个节点的最近公共祖先

题目

给定一棵二叉树的头节点root,以及这棵树中的两个节点node1, node2,返回node1, node2的最近公共祖先节点。

TreeNode

"""
Data Structures And Algorithms
    find lowest ancestor of tow nodes in a binary tree
"""
import random
from collections import deque


class TreeNode():  # pylint: disable=too-few-public-methods
    """Binary Tree Node
    """
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
        self.parent = None

BST

class BST():
    """Binary Search Tree
    """
    def __init__(self):
        self.root = None

    def insert(self, val):
        """insert a val to BST
        """
        def recursive(node, val):
            """recursive insert a val to BST
            """
            if not node:
                return TreeNode(val)

            if val < node.val:
                node.left = recursive(node.left, val)
            elif val > node.val:
                node.right = recursive(node.right, val)

            return node

        self.root = recursive(self.root, val)

    def search(self, val):
        """search value from BST
        """
        def recursive(node, val):
            """recursive search value from BST
            """
            if node is None:
                return None

            if val < node.val:
                return recursive(node.left, val)
            if val > node.val:
                return recursive(node.right, val)
            return node

        return recursive(self.root, val)

思路1

后序遍历,假设处理cur节点的左子树返回left,右子树返回right,现在处理cur,

  • cur 是None,或者cur是node1 或node2,返回cur
  • left, right都是None,说明子树上未发现node1, node2,返回None
  • left, right都不是None,说明分别在左右子树上找到了node1, node2,cur就是最近的公共祖先
  • left, right一个为None,一个不为None,假设不为None的那个节点是node,则node子树可能:
    1)发现了node1或node2
    2)发现了node2,node2的公共祖先
    返回node
def lowest_ancestor1(root, node1, node2):
    """get lowest ancestor, method 1
    """
    if root is None or root == node1 or root == node2:
        return root

    left = lowest_ancestor1(root.left, node1, node2)
    right = lowest_ancestor1(root.right, node1, node2)
    if left is not None and right is not None:
        return root

    return left if left is not None else right

思路2

实现一个判断一个节点是不是另一个节点祖先的函数,

  • 如果node1是node2的祖先,返回node1
  • 如果node2是node1的祖先,返回node2
  • 否则node设为root,执行
    • 如果node的左子树是node1, node2的公共祖先,转到左子树
    • 如果node的左子树是node1, node2的公共祖先,转到右子树
  • 直到左右子树都不是node1, node2的公共祖先,则node为所找的节点
def lowest_ancestor2(root, node1, node2):
    """get lowest ancestor, method 2
    """
    def is_ancestor_of(ancestor, node):
        """check if parent is parent of child
        """
        if ancestor is None:
            return False
        if ancestor == node:
            return True
        return (is_ancestor_of(ancestor.left, node) or
                is_ancestor_of(ancestor.right, node))

    if root in (None, node1, node2):
        return root

    if is_ancestor_of(node1, node2):
        return node1
    if is_ancestor_of(node2, node1):
        return node2

    node = root
    while node:
        left, right = node.left, node.right
        if (is_ancestor_of(left, node1) and is_ancestor_of(left, node2)):
            node = node.left
        elif (is_ancestor_of(right, node1) and is_ancestor_of(right, node2)):
            node = node.right
        else:
            return node

    return None

思路3

找到root->node1的路径path1,root->node2的路径path2,找到path1,path2中最后一个相同的node,就是最近公共祖先

def lowest_ancestor3(root, node1, node2):
    """get lowest ancestor, method 3
    """
    def get_path(cur, node, path):
        """get path from root to node
        """
        if cur is None:
            return False

        # add current node
        path.append(cur)

        if cur == node:
            return True
        if get_path(cur.left, node, path):
            return True
        if get_path(cur.right, node, path):
            return True

        path.pop()
        return False

    # get path of two nodes
    path1 = []
    path2 = []
    get_path(root, node1, path1)
    get_path(root, node2, path2)

    i = 0
    while i < min(len(path1), len(path2)) and path1[i] == path2[i]:
        i += 1

    return path1[i-1]

思路4

分层遍历过程中记录各个节点的父节点,然后比较路径

def lowest_ancestor4(root, node1, node2):
    """get lowest ancestor, method 4
    """
    deq = deque()
    deq.append(root)
    parents_dic = {root: None}
    while node1 not in parents_dic or node2 not in parents_dic:
        node = deq.popleft()
        if node.left is not None:
            parents_dic[node.left] = node
            deq.append(node.left)
        if node.right is not None:
            parents_dic[node.right] = node
            deq.append(node.right)

    ancestors = set()
    while node1:
        ancestors.add(node1)
        node1 = parents_dic[node1]
    while node2 not in ancestors:
        node2 = parents_dic[node2]

    return node2

进阶题目

如果查询操作非常频繁,怎样处理?

思路1

上面思路4中,将所有节点的父节点记录下来,查询的时候只需要找到路径,然后比较路径。

class LowestAncestorRecord1():
    """record of lowest ancestor
    """
    def __init__(self, root):
        self.map = {}
        if root is not None:
            self.map[root] = None

        self.set_map(root)

    def set_map(self, node):
        """set map for node
        """
        if node is None:
            return
        if node.left is not None:
            self.map[node.left] = node
        if node.right is not None:
            self.map[node.right] = node

        self.set_map(node.left)
        self.set_map(node.right)

    def query(self, node1, node2):
        """query lowest ancestor or node1 and node2
        """
        path = set()
        while node1 in self.map:
            path.add(node1)
            node1 = self.map[node1]

        while node2 not in path:
            node2 = self.map[node2]

        return node2

思路2

开始的时候就找到任意两个节点node1, node2的最近公共祖先,查询的时候查字典。
对于二叉树中任意一棵子树,假设子树根节点root,则

  • root的所有后代节点根root的最近公共祖先都是root
  • root的左子树中每个节点和root的右子树中每个节点的最近公共祖先都是root
class LowestAncestorRecord2():
    """record of lowest ancestor, method 2
    """
    def __init__(self, root):
        self.map = {}
        self.set_map(root)

    def set_map(self, node):
        """set map for node
        """
        if node is None:
            return

        self.process_node(node.left, node)
        self.process_node(node.right, node)
        self.process_left_right(node)

        self.set_map(node.left)
        self.set_map(node.right)

    def process_node(self, node, ancestor):
        """ process node and node's descendant
        map[(node, node's descendant)] = ancestor
        """
        if node is None:
            return

        self.map[(node, ancestor)] = ancestor
        self.process_node(node.left, ancestor)
        self.process_node(node.right, ancestor)

    def process_left_right(self, node):
        """ map[(node of node's left subtree,
        node of node's right subtree)] = node
        """
        if node is None:
            return

        self.process_left(node.left, node.right, node)
        self.process_left_right(node.left)
        self.process_left_right(node.right)

    def process_left(self, left, right, ancestor):
        """process left
        """
        if left is None:
            return

        self.process_right(left, right, ancestor)
        self.process_left(left.left, right, ancestor)
        self.process_left(left.right, right, ancestor)

    def process_right(self, left, right, ancestor):
        """ process right
        """
        if right is None:
            return

        self.map[(left, right)] = ancestor
        self.process_right(left, right.left, ancestor)
        self.process_right(left, right.right, ancestor)

    def query(self, node1, node2):
        """do query
        """
        if node1 == node2:
            return node1

        if (node1, node2) in self.map:
            return self.map[(node1, node2)]
        return self.map[(node2, node1)]

思路3

查询过程中记录下查过的结果,每次查找先查字典,字典没有再真正搜索

class LowestAncestorRecord3():  # pylint: disable=too-few-public-methods
    """record of lowest ancestor, method 3
    """
    def __init__(self, root):
        self.root = root
        self.map = {}

    def query(self, node1, node2):
        """query lowest ancestor
        """
        if (node1, node2) in self.map:
            return self.map[(node1, node2)]
        if (node2, node1) in self.map:
            return self.map[(node2, node1)]

        ret = lowest_ancestor4(self.root, node1, node2)
        self.map[(node1, node2)] = ret
        return ret

测试

def test_lowest_ancestor(count, test_count):
    """test lowest ancestor
    """
    bst = BST()
    vals = [i for i in range(count)]
    random.shuffle(vals)
    for val in vals:
        bst.insert(val)

    dic = {}
    for i in range(count):
        dic[i] = bst.search(i)

    record1 = LowestAncestorRecord1(bst.root)
    record2 = LowestAncestorRecord2(bst.root)
    record3 = LowestAncestorRecord3(bst.root)

    for __ in range(test_count):
        results = []
        node1 = dic[random.randint(0, count-1)]
        node2 = dic[random.randint(0, count-1)]

        results.append(lowest_ancestor1(bst.root, node1, node2))
        results.append(lowest_ancestor2(bst.root, node1, node2))
        results.append(lowest_ancestor3(bst.root, node1, node2))
        results.append(lowest_ancestor4(bst.root, node1, node2))
        results.append(record1.query(node1, node2))
        results.append(record2.query(node1, node2))
        results.append(record3.query(node1, node2))

        for i in range(1, len(results)):
            if results[i-1].val != results[i].val:
                print(i)
                raise Exception('Error')


if __name__ == '__main__':
    test_lowest_ancestor(10, 1000)
    test_lowest_ancestor(30, 100)
    test_lowest_ancestor(300, 100)
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值