用递归的方法先找到o1和o2对应的路径,再寻找该路径下相同的节点,即为最近公共祖先。

代码中需要注意的点:

  • 对于list类型的操作,注意执行完append操作之后不需要再返回list(结合其他语言对指针的理解)
  • 二叉树深度优先搜索的实现
    def lowestCommonAncestor(self , root: TreeNode, o1: int, o2: int) -> int:
        # write code here
        path1, path2 = [], []
        self.dfs(root, path1, o1)
        self.flag = False
        self.dfs(root, path2, o2)
        i = 0
        res = None
        print(path1, path2)                                         
        while i < len(path1) and i < len(path2):
            if path1[i] == path2[i]:
                res = path1[i]
                i += 1
            else:
                break
        return res

    
    def dfs(self, root: TreeNode, path: List[int], o:int):
        if self.flag or not root:
            return 
        path.append(root.val)
        if root.val == o:
            self.flag = True
            return
        self.dfs(root.left, path, o)
        self.dfs(root.right, path, o)
        if self.flag:
            return
        path.pop()

反例:用下面的方法计算出来能通过9/10用例,但内存超出了限额:

import queue
class Solution:
    def lowestCommonAncestor(self , root: TreeNode, o1: int, o2: int) -> int:
        # 由于时间复杂度为O(n),只能以空间换时间,用层序遍历
        num_mapping = dict()
        layer_traverse = queue.Queue()
        layer_traverse.put(root)
        num_mapping[root.val] = str(root.val)
        while layer_traverse:
            l_size = layer_traverse.qsize()
            for _ in range(l_size):
                parent = layer_traverse.get()
                path = num_mapping[parent.val]
                if parent.left:
                    left_path = path + '#'+ str(parent.left.val)
                    num_mapping[parent.left.val] = left_path
                    layer_traverse.put(parent.left)
                if parent.right:
                    right_path = path + '#'+ str(parent.right.val)
                    num_mapping[parent.right.val] = right_path
                    layer_traverse.put(parent.right)  
            if o1 in num_mapping and o2 in num_mapping:
                break
        
        o1_path = str(num_mapping[o1]).split('#')
        o2_path = str(num_mapping[o2]).split('#')
        len_o1_path = len(o1_path)
        len_o2_path = len(o2_path)
        i = 0
        for i in range(len_o1_path if len_o1_path < len_o2_path else len_o2_path):
            if o1_path[i] != o2_path[i]:
                i = i - 1
                break
        return int(o1_path[i])