/**
* Definition for a binary tree node.
* public class TreeNode {
* int val;
* TreeNode left;
* TreeNode right;
* TreeNode(int x) { val = x; }
* }
*/
class Solution {
    /**
     * Returns the largest entry of the table that {@link #getDP} builds for the whole tree.
     *
     * @param root root of the tree; may be {@code null}
     * @param k    largest colored-node count tracked by the table (the table has {@code k + 1} slots)
     * @return the maximum over every slot of the root's table
     */
    public int maxValue(TreeNode root, int k) {
        int best = Integer.MIN_VALUE;
        for (int candidate : getDP(root, k)) {
            best = Math.max(best, candidate);
        }
        return best;
    }

    /**
     * Builds the table for the subtree rooted at {@code root}.
     *
     * <p>Slot {@code 0} is the case where this node is left uncolored: each child contributes
     * the best entry of its own table. Slot {@code i} (for {@code i >= 1}) is the case where
     * this node is colored and the colored counts taken from the left and right children add
     * up to {@code i - 1}. Slots start at zero, so a slot {@code i >= 1} never drops below zero.
     *
     * @param root     subtree root; a {@code null} subtree yields an all-zero table
     * @param maxCount highest slot index of the table
     * @return a fresh array of length {@code maxCount + 1}
     */
    public int[] getDP(TreeNode root, int maxCount) {
        if (root == null) {
            return new int[maxCount + 1];
        }

        int[] table = new int[maxCount + 1];
        // Tables of the two children.
        int[] leftTable = getDP(root.left, maxCount);
        int[] rightTable = getDP(root.right, maxCount);

        // Current node not colored: the children are independent, so take each one's best slot.
        table[0] = largest(leftTable) + largest(rightTable);

        // Current node colored: pair every left count with every right count that still fits,
        // and record the sum under slot (left + right + 1).
        for (int fromLeft = 0; fromLeft < maxCount; fromLeft++) {
            for (int fromRight = 0; fromLeft + fromRight < maxCount; fromRight++) {
                int slot = fromLeft + fromRight + 1;
                int candidate = root.val + leftTable[fromLeft] + rightTable[fromRight];
                if (candidate > table[slot]) {
                    table[slot] = candidate;
                }
            }
        }
        return table;
    }

    /** Returns the largest element of {@code values}, or {@code Integer.MIN_VALUE} if it is empty. */
    private static int largest(int[] values) {
        int top = Integer.MIN_VALUE;
        for (int value : values) {
            top = Math.max(top, value);
        }
        return top;
    }
}