题目:
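Given a binary tree, find the maximum path sum.
A path is any sequence of nodes connected by parent-child edges; it must contain at least one node, and it may start and end at any node in the tree (it need not pass through the root).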
For example:
Given the binary tree below,
    1
   / \
  2   3
Return 6 (the path 2 -> 1 -> 3).
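此题不需要求出产生最大值的路径,而只需要求最大值,故而不难。用递归的后序遍历(Post-order Traversal)即可:每个节点向父节点返回“从该节点出发向下”的最大路径和(节点本身,或再接上左右两支中较优的一支),同时用“左支 + 节点值 + 右支”这条在该节点处拐弯的路径去更新全局最大值。注意Java按值传递int,而Python的int是不可变对象(在函数内对参数重新赋值不会影响调用者),故而需要用一个class把全局最大值封装起来。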
C++版:
class Solution {
public:
    /// Returns the maximum sum over all node-to-node paths in the tree.
    /// A path follows parent-child links and contains at least one node.
    /// For an empty tree (root == nullptr) no path exists and INT_MIN is returned.
    int maxPathSum(TreeNode* root) {
        int globalM = INT_MIN;
        find(root, globalM);
        return globalM;
    }

    /// Post-order helper.
    /// @param root     Root of the subtree to scan; may be nullptr.
    /// @param globalM  Running best path sum; raised whenever a better path
    ///                 inside this subtree is found.
    /// @return Best sum of a downward path starting at root (root alone, or root
    ///         extended into one child's best downward path); 0 for nullptr so a
    ///         missing child contributes nothing.
    int find(TreeNode* root, int& globalM) {
        if (root == nullptr) {
            return 0;
        }
        const int left = find(root->left, globalM);
        const int right = find(root->right, globalM);

        // Best downward path: extend into the better child only when that helps.
        const int bestChild = left > right ? left : right;
        const int localM = bestChild > 0 ? root->val + bestChild : root->val;

        // Path that bends at root: left branch + root + right branch. It is summed
        // directly. Writing it as (left + val) + (right + val) - val would need an
        // intermediate holding val twice, which can overflow (undefined behavior
        // for signed int) even when the real path sum fits in an int.
        const int throughRoot = left + right + root->val;

        if (localM > globalM) {
            globalM = localM;
        }
        if (throughRoot > globalM) {
            globalM = throughRoot;
        }
        return localM;
    }
};
Java版:
/**
* Definition for a binary tree node.
* public class TreeNode {
* int val;
* TreeNode left;
* TreeNode right;
* TreeNode(int x) { val = x; }
* }
*/
/**
 * Mutable holder for the best path sum found so far. Java passes ints by
 * value, so the recursive helper writes into this object instead of
 * returning a second value.
 */
class Result {
    int result;

    Result(int x) {
        this.result = x;
    }
}
public class Solution {
    /**
     * Returns the largest sum over all node-to-node paths in the tree.
     * A path follows parent-child links and contains at least one node.
     * An empty tree yields Integer.MIN_VALUE.
     */
    public int maxPathSum(TreeNode root) {
        Result best = new Result(Integer.MIN_VALUE);
        find(root, best);
        return best.result;
    }

    /**
     * Post-order helper.
     *
     * @param root   subtree root; may be null
     * @param global running best path sum, raised when this subtree holds a
     *               better path
     * @return best sum of a downward path that starts at {@code root}
     *         ({@code root} alone or extended into one child); 0 for null
     */
    int find(TreeNode root, Result global) {
        if (root != null) {
            int fromLeft = find(root.left, global);
            int fromRight = find(root.right, global);
            // Downward path: node alone, or node plus one child's best branch.
            int downward = Math.max(root.val,
                    Math.max(fromLeft + root.val, fromRight + root.val));
            // Path that bends at this node, using both branches.
            int bent = fromLeft + fromRight + root.val;
            global.result = Math.max(global.result, Math.max(downward, bent));
            return downward;
        }
        return 0;
    }
}
Python版:
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, x):
# self.val = x
# self.left = None
# self.right = None
import sys
class Result:
    """Mutable box for the best path sum found so far.

    Python ints are immutable, so rebinding an int parameter inside a helper
    would not be seen by the caller; the helper updates ``self.result`` instead.
    """

    def __init__(self, x):
        # Best path sum seen so far (starts at a sentinel chosen by the caller).
        self.result = x


class Solution:
    # @param {TreeNode} root
    # @return {integer}
    def maxPathSum(self, root):
        """Return the maximum sum over all node-to-node paths in the tree.

        A path follows parent-child links and contains at least one node.
        For an empty tree (``root is None``) no path exists, and the sentinel
        ``-sys.maxsize`` is returned.
        """
        # sys.maxint no longer exists in Python 3; sys.maxsize is available in
        # both Python 2.6+ and Python 3.
        globalM = Result(-sys.maxsize)
        self.find(root, globalM)
        return globalM.result

    def find(self, root, globalM):
        """Post-order pass over the subtree rooted at ``root``.

        Updates ``globalM.result`` with the best path sum found anywhere in the
        subtree. Returns the best sum of a downward path that starts at
        ``root`` (``root`` alone, or ``root`` extended into one child).
        Returns 0 for ``None`` so a missing child contributes nothing.

        Uses an explicit stack instead of recursion, so deep (e.g. list-shaped)
        trees do not exceed Python's recursion limit.
        """
        if root is None:
            return 0
        # id(node) -> best downward path sum starting at that node. An entry is
        # removed as soon as its parent has consumed it.
        gains = {}
        stack = [(root, False)]
        while stack:
            node, children_done = stack.pop()
            if not children_done:
                # Revisit this node after both children are done; push the
                # right child first so the left subtree is processed first.
                stack.append((node, True))
                if node.right is not None:
                    stack.append((node.right, False))
                if node.left is not None:
                    stack.append((node.left, False))
                continue
            left = gains.pop(id(node.left)) if node.left is not None else 0
            right = gains.pop(id(node.right)) if node.right is not None else 0
            local = max(node.val, left + node.val, right + node.val)
            # Either the downward path or the path bending at this node
            # (left branch + node + right branch) may be the new best.
            globalM.result = max(globalM.result, local, left + right + node.val)
            gains[id(node)] = local
        return gains[id(root)]