给你一棵以 root
为根的 二叉树 ,请你返回 任意 二叉搜索子树的最大键值和。
二叉搜索树的定义如下:
- 任意节点的左子树中的键值都 小于 此节点的键值。
- 任意节点的右子树中的键值都 大于 此节点的键值。
- 任意节点的左子树和右子树都是二叉搜索树。
示例 1:
输入:root = [1,4,3,2,4,2,5,null,null,null,null,null,null,4,6] 输出:20 解释:键值为 3 的子树是和最大的二叉搜索树。
示例 2:
输入:root = [4,3,null,1,2] 输出:2 解释:键值为 2 的单节点子树是和最大的二叉搜索树。
示例 3:
输入:root = [-4,-2,-5] 输出:0 解释:所有节点键值都为负数,和最大的二叉搜索树为空。
示例 4:
输入:root = [2,1,3] 输出:6
示例 5:
输入:root = [5,4,8,3,null,6,3] 输出:7
提示:
- 每棵树有
1
到40000
个节点。 - 每个节点的键值在
[-4 * 10^4 , 4 * 10^4]
之间。
map<TreeNode*, bool> mapBTree;
map<TreeNode*, int> mapITree;
map<TreeNode*, int> mapCurMax;
map<TreeNode*, int> mapCurMin;
int find(TreeNode* root)
{
if (nullptr == root)
{
return 0;
}
if (mapITree.count(root)>1)
{
return mapITree[root];
}
int left = find(root->left);
int right = find(root->right);
if (root->left == nullptr || root->right == nullptr)
{
if (root->left == nullptr && root->right == nullptr)
{
mapBTree[root] = true;
mapITree[root] = root->val + left + right;
mapCurMax[root] = root->val;
mapCurMin[root] = root->val;
return root->val + left + right;
}
else if (root->val < mapCurMin[root->right] && mapBTree[root->right] == true && root->left == nullptr && root->right != nullptr)
{
if (root->val < root->right->val)
{
mapBTree[root] = true;
mapITree[root] = root->val + 0 + right;
mapCurMax[root] = root->val>root->right->val ? root->val : root->right->val;
mapCurMin[root] = root->val>root->right->val ? root->right->val : root->val;
return root->val + 0 + right;
}
else
{
mapBTree[root] = false;
mapITree[root] = 0;
return 0;
}
}
else if (root->left != nullptr && root->right == nullptr)
{
if (root->val > mapCurMax[root->left] && mapBTree[root->left] == true && root->val > root->left->val)
{
mapBTree[root] = true;
mapITree[root] = root->val + left + 0;
mapCurMax[root] = root->val>root->left->val ? root->val : root->left->val;
mapCurMin[root] = root->val>root->left->val ? root->left->val : root->val;
return root->val + 0 + left;
}
else
{
mapBTree[root] = false;
mapITree[root] = 0;
return 0;
}
}
}
if (root->left != nullptr && root->right != nullptr && root->val > root->left->val && root->val <root->right->val)
{
//大于左子树最大值,小于右子树最小值
if (root->val < mapCurMin[root->right] && root->val > mapCurMax[root->left] && mapBTree[root->left] == true && mapBTree[root->right] == true)
{
mapBTree[root] = true;
mapITree[root] = root->val + left + right;
mapCurMax[root] = root->right->val;
mapCurMin[root] = root->left->val;
return root->val + left + right;
}
}
mapBTree[root] = false;
mapITree[root] = 0;
return 0;
}
int maxSumBST(TreeNode* root) {
int ret = find(root);
int maxInt = 0;
if (root->left == nullptr || root->right == nullptr)
{
if (root->left == nullptr && root->right != nullptr)
{
mapBTree[root] = false;
mapITree[root] = 0;
}
else if (root->left != nullptr && root->right == nullptr)
{
mapBTree[root] = false;
mapITree[root] = 0;
}
}
for (auto it : mapITree)
{
if (it.second > maxInt &&mapBTree[it.first]==true)
{
maxInt = it.second;
}
}
return maxInt;
}
void test()
{
TreeNode* n1 = new TreeNode(4);
TreeNode* n2 = new TreeNode(3);
TreeNode* n3 = new TreeNode(1);
TreeNode* n4 = new TreeNode(2);
n1->left=n2;
n2->left = n3;
n2->right = n4;
int ret = maxSumBST(n1);
}