Given a binary tree, find the maximum path sum.
For this problem, a path is defined as any sequence of nodes from some starting node to any node in the tree along the parent-child connections. The path must contain at least one node and does not need to go through the root.
For example:
Given the binary tree below,
       1
      / \
     2   3
return 6 (the path 2 -> 1 -> 3).
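To make the definition of a path concrete, here is a brute-force sketch. It is not the approach used below, only a reference. It treats the tree as an undirected graph, so a path may go up through a parent. From every node it walks every simple path while tracking the running sum, which takes O(n^2) time. The class name BruteForcePathSum and its helpers are hypothetical, chosen for illustration. It can also be used to check a faster solution on small trees.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class BruteForcePathSum {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    // Build an undirected adjacency list so a path may also move from a child up to its parent.
    static void buildGraph(TreeNode node, Map<TreeNode, List<TreeNode>> adj) {
        if (node == null) return;
        adj.putIfAbsent(node, new ArrayList<>());
        for (TreeNode child : new TreeNode[]{node.left, node.right}) {
            if (child == null) continue;
            adj.putIfAbsent(child, new ArrayList<>());
            adj.get(node).add(child);
            adj.get(child).add(node);
            buildGraph(child, adj);
        }
    }

    // Walk away from 'from' only; in a tree this enumerates every simple path starting at the
    // original start node. Returns the best sum over all such paths (every prefix counts).
    static int dfs(TreeNode node, TreeNode from, int sumSoFar, Map<TreeNode, List<TreeNode>> adj) {
        int sum = sumSoFar + node.val;
        int best = sum;
        for (TreeNode next : adj.get(node)) {
            if (next != from) best = Math.max(best, dfs(next, node, sum, adj));
        }
        return best;
    }

    static int maxPathSum(TreeNode root) {
        Map<TreeNode, List<TreeNode>> adj = new HashMap<>();
        buildGraph(root, adj);
        int best = Integer.MIN_VALUE;
        for (TreeNode start : adj.keySet()) {
            best = Math.max(best, dfs(start, null, 0, adj));
        }
        return best;
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(3);
        System.out.println(maxPathSum(root)); // prints 6
    }
}
```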
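My approach is simple: for each node, suppose it is the highest node on the path and compute the maximum sum of a path that peaks there. That sum is the node's value plus the best downward path from its left child and from its right child. A child's branch is included only if its sum is positive. The answer is the maximum over all nodes.
The code: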
/**
* Definition for a binary tree node.
* public class TreeNode {
* int val;
* TreeNode left;
* TreeNode right;
* TreeNode(int x) { val = x; }
* }
*/
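import java.util.HashMap;

public class Solution {
    // Largest path sum found so far.
    int max = Integer.MIN_VALUE;
    // Memo: node -> best sum of a downward path that starts at that node.
    HashMap<TreeNode, Integer> hm = new HashMap<>();

    public int maxPathSum(TreeNode root) {
        helper(root);
        return max;
    }

    // Treat every node as the highest node of the path and update max.
    void helper(TreeNode root) {
        if (root == null) return;
        int left = maxPathDown(root.left);
        int right = maxPathDown(root.right);
        // Extend into a child's branch only if it adds a positive amount.
        int sum = root.val;
        if (left > 0) sum += left;
        if (right > 0) sum += right;
        if (sum > max) max = sum;
        helper(root.left);
        helper(root.right);
    }

    // Best sum of a path that starts at root and only goes downward
    // (0 for an empty subtree); results are cached in hm.
    int maxPathDown(TreeNode root) {
        if (root == null) return 0;
        if (hm.containsKey(root)) return hm.get(root);
        int ans = root.val + Math.max(Math.max(maxPathDown(root.left), maxPathDown(root.right)), 0);
        hm.put(root, ans);
        return ans;
    }
}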
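The helper() method treats each node as the highest node of a path, computes that path's sum, and updates max.
The maxPathDown() method computes the best sum of a path that starts at a given node and goes only downward. It stops early rather than extend into a branch whose sum is negative, so the path does not have to reach a leaf.
A HashMap memoizes maxPathDown() so that no subtree is recomputed.
I believe this solution runs in O(n) time, since every node is processed twice: once in helper() and once the first time maxPathDown() computes it. Later calls are HashMap lookups. The HashMap costs O(n) extra space.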
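One way to check the O(n) claim empirically is to count calls. The sketch below is an instrumented copy of the memoized solution. The class name CallCounter, the useMemo flag, the calls counter and the chain() builder are hypothetical additions for illustration. It runs on a 1000-node left-leaning chain, the worst shape for repeated work. With the HashMap, maxPathDown() is invoked 4n - 2 times. That count comes from two calls per node from helper() plus two child calls for each of the n - 1 non-root nodes it actually computes. The root's own value is never requested. Without the memo, the count grows quadratically, to n(n + 1) on a chain.

```java
import java.util.HashMap;

public class CallCounter {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    // Instrumented copy of the memoized solution; useMemo toggles the HashMap on or off.
    final boolean useMemo;
    final HashMap<TreeNode, Integer> hm = new HashMap<>();
    int max = Integer.MIN_VALUE;
    long calls = 0; // total invocations of maxPathDown, including calls on null children

    CallCounter(boolean useMemo) { this.useMemo = useMemo; }

    int maxPathSum(TreeNode root) {
        helper(root);
        return max;
    }

    void helper(TreeNode root) {
        if (root == null) return;
        int left = maxPathDown(root.left);
        int right = maxPathDown(root.right);
        int sum = root.val + Math.max(left, 0) + Math.max(right, 0);
        max = Math.max(max, sum);
        helper(root.left);
        helper(root.right);
    }

    int maxPathDown(TreeNode root) {
        calls++;
        if (root == null) return 0;
        if (useMemo && hm.containsKey(root)) return hm.get(root);
        int ans = root.val + Math.max(Math.max(maxPathDown(root.left), maxPathDown(root.right)), 0);
        if (useMemo) hm.put(root, ans);
        return ans;
    }

    // Left-leaning chain of n nodes, each with value 1.
    static TreeNode chain(int n) {
        TreeNode root = null;
        for (int i = 0; i < n; i++) {
            TreeNode node = new TreeNode(1);
            node.left = root;
            root = node;
        }
        return root;
    }

    public static void main(String[] args) {
        int n = 1000;
        CallCounter memo = new CallCounter(true);
        CallCounter plain = new CallCounter(false);
        System.out.println(memo.maxPathSum(chain(n)) + " " + memo.calls);   // 1000 3998
        System.out.println(plain.maxPathSum(chain(n)) + " " + plain.calls); // 1000 1001000
    }
}
```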
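A smarter way is to do the update inside the recursion itself: maxPathDown() updates max on its way back up, so a single post-order traversal is enough and no HashMap is needed.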
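public class Solution {
    // Largest path sum found so far.
    int max = Integer.MIN_VALUE;

    public int maxPathSum(TreeNode root) {
        maxPathDown(root);
        return max;
    }

    // Post-order: returns the best downward path sum starting at root and,
    // as a side effect, updates max with the best path whose highest node is root.
    int maxPathDown(TreeNode root) {
        if (root == null) return 0;
        int left = maxPathDown(root.left);
        int right = maxPathDown(root.right);
        // Path that bends at root: include each child branch only if it is positive.
        int sum = root.val;
        if (left > 0) sum += left;
        if (right > 0) sum += right;
        if (sum > max) max = sum;
        // A path handed back to the parent can continue into at most one child.
        return root.val + Math.max(0, Math.max(left, right));
    }
}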
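Below is a self-contained harness for the single-pass version. It assumes LeetCode-style level-order input where null marks a missing child. fromLevelOrder() is a hypothetical helper written for this sketch, and the solver restates the code above in a class named SinglePassPathSum. The second and third trees show why max starts at Integer.MIN_VALUE and why negative branches are dropped. In -10, 9, 20, null, null, 15, 7 the best path 15 -> 20 -> 7 skips the root, and a lone -3 must still return -3.

```java
import java.util.ArrayDeque;

public class SinglePassPathSum {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    private int max = Integer.MIN_VALUE;

    // A fresh instance per call so max never leaks between trees.
    static int maxPathSum(TreeNode root) {
        SinglePassPathSum s = new SinglePassPathSum();
        s.maxPathDown(root);
        return s.max;
    }

    // Same logic as the single-pass solution above.
    private int maxPathDown(TreeNode root) {
        if (root == null) return 0;
        int left = maxPathDown(root.left);
        int right = maxPathDown(root.right);
        int sum = root.val + Math.max(0, left) + Math.max(0, right);
        max = Math.max(max, sum);
        return root.val + Math.max(0, Math.max(left, right));
    }

    // Hypothetical helper: builds a tree from level-order values, where null means no child.
    static TreeNode fromLevelOrder(Integer... vals) {
        if (vals.length == 0 || vals[0] == null) return null;
        TreeNode root = new TreeNode(vals[0]);
        ArrayDeque<TreeNode> queue = new ArrayDeque<>();
        queue.add(root);
        int i = 1;
        while (i < vals.length && !queue.isEmpty()) {
            TreeNode parent = queue.poll();
            if (vals[i] != null) {
                parent.left = new TreeNode(vals[i]);
                queue.add(parent.left);
            }
            i++;
            if (i < vals.length && vals[i] != null) {
                parent.right = new TreeNode(vals[i]);
                queue.add(parent.right);
            }
            i++;
        }
        return root;
    }

    public static void main(String[] args) {
        System.out.println(maxPathSum(fromLevelOrder(1, 2, 3)));                        // 6
        System.out.println(maxPathSum(fromLevelOrder(-10, 9, 20, null, null, 15, 7))); // 42
        System.out.println(maxPathSum(fromLevelOrder(-3)));                            // -3
    }
}
```

Allocating a new instance per call is a design choice for the harness: the original keeps max as a field of Solution, which is fine when each instance handles one tree but would give stale results if reused.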