难度：中等。
标签：树，深度优先搜索，动态规划。
先用暴力递归来做，结果超时了：没有记忆化，同一个节点“选 / 不选”两种状态下的子问题会被反复重复计算，调用次数随树的深度呈指数级增长。
超时代码：
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
    // Plain recursion without memoization (the version that exceeds the time limit).
    // Returns the maximum loot from the subtree rooted at `cur`:
    //   flag != 0 -> `cur` itself is robbed, so neither child may be robbed;
    //   flag == 0 -> `cur` is skipped, so each child takes its better option.
    // The returned value is never below 0. Subproblems are recomputed on every
    // call, so the number of calls grows exponentially with the tree's depth.
    int dfs(TreeNode* cur, int flag){
        if(cur == nullptr){
            return 0;  // empty subtree contributes nothing
        }
        if(!flag){
            const int leftBest = max(dfs(cur->left, 0), dfs(cur->left, 1));
            const int rightBest = max(dfs(cur->right, 0), dfs(cur->right, 1));
            return max(0, leftBest + rightBest);
        }
        const int robbed = cur->val + dfs(cur->left, 0) + dfs(cur->right, 0);
        return max(0, robbed);
    }
public:
    // Best total loot: compare robbing the root against skipping it.
    int rob(TreeNode* root) {
        const int withRoot = dfs(root, 1);
        const int withoutRoot = dfs(root, 0);
        return max(withRoot, withoutRoot);
    }
};
记忆化 DFS：用 $f[node]$ 表示不选择当前节点时，以该节点为根的子树能获得的最大价值；用 $g[node]$ 表示选择当前节点时，该子树能获得的最大价值。
正确解法:
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
    // Memo tables keyed by node address:
    //   f[node] = max loot from node's subtree when node is NOT robbed,
    //   g[node] = max loot from node's subtree when node IS robbed.
    // They are members, so rob() clears them before each run: if this object
    // were reused for another tree, a node allocated at a previously seen
    // address would otherwise be served a stale cached value.
    unordered_map<TreeNode*, int> f, g;

    // Returns the max loot for the subtree rooted at `cur` (must be non-null),
    // robbing `cur` when flag != 0 and skipping it when flag == 0. The result
    // is also stored in g[cur] (flag != 0) or f[cur] (flag == 0). Children are
    // looked up in the memo tables first, so each (node, state) pair is
    // computed at most once. Results are clamped below at 0 by max(result, ...).
    int dfs(TreeNode* cur, int flag){
        int result = 0;
        if(flag){
            // Robbing cur: both children must be skipped.
            int k_left = 0, k_right = 0;
            if(cur->left != nullptr){
                if(f.find(cur->left) == f.end())f[cur->left] = dfs(cur->left, 0);
                k_left = f[cur->left];
            }
            if(cur->right != nullptr){
                if(f.find(cur->right) == f.end())f[cur->right] = dfs(cur->right, 0);
                k_right = f[cur->right];
            }
            result = max(result, cur->val + k_left + k_right);
            g[cur] = result;
        }
        else{
            // Skipping cur: each child independently takes its better state.
            int k_left = 0, k_right = 0;
            if(cur->left != nullptr){
                if(f.find(cur->left) == f.end())f[cur->left] = dfs(cur->left, 0);
                if(g.find(cur->left) == g.end())g[cur->left] = dfs(cur->left, 1);
                k_left = max(f[cur->left], g[cur->left]);
            }
            if(cur->right != nullptr){
                if(f.find(cur->right) == f.end())f[cur->right] = dfs(cur->right, 0);
                if(g.find(cur->right) == g.end())g[cur->right] = dfs(cur->right, 1);
                k_right = max(f[cur->right], g[cur->right]);
            }
            result = max(result, k_left + k_right);
            f[cur] = result;
        }
        return result;
    }
public:
    int rob(TreeNode* root) {
        // An empty tree has nothing to rob; dfs() would dereference null.
        if(root == nullptr) return 0;
        // Drop results cached by a previous call (see note on f/g above).
        f.clear();
        g.clear();
        dfs(root, 0);
        dfs(root, 1);
        return max(f[root], g[root]);
    }
};
结果:
与官方题解相比，上面的代码多了很多重复、无用的部分，需要简化！
通过对官方代码的测试观察，发现读取子节点的结果时不需要判断子节点是否为 NULL：当 map 中没有键 $K$ 时，调用 $f[K]$ 会自动插入一个键为 $K$、值为 0（int 的值初始化结果）的项，这正好是空子树应有的值。如下代码所示：
unordered_map<TreeNode*, int> f, g;
// Probe showing that operator[] on a missing key inserts that key with a
// value-initialized int (0). Prints every entry of f ("NULL <value>" for the
// null key, "<node val> <value>" otherwise), then returns max(f[root], g[root]).
int rob(TreeNode* root) {
    f[root] = f[nullptr];   // reading f[nullptr] inserts it as 0; f[root] is inserted too
    g[root] = 1;
    for(const auto& entry : f){
        TreeNode* key = entry.first;
        if(key != NULL){
            cout << key->val << ' ' << entry.second << endl;
        }
        else{
            cout << "NULL " << entry.second << endl;
        }
    }
    return max(f[root], g[root]);
}
输入:
[1]
cout 输出（unordered_map 的遍历顺序未指定，两行的先后顺序可能与下面不同）：
1 0
NULL 0
因此,可将代码简化为:
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
unordered_map<TreeNode*, int> f, g;
void dfs(TreeNode* cur){
if(cur == NULL)return;
dfs(cur->left);
dfs(cur->right);
g[cur] = cur->val + f[cur->left] + f[cur->right];
f[cur] = max(f[cur->left], g[cur->left]) + max(f[cur->right], g[cur->right]);
}
public:
int rob(TreeNode* root) {
dfs(root);
return max(f[root], g[root]);
}
};
结果:
继续简化：可以去掉 map，让 dfs 直接返回一个结构体，同时带回“选择”和“不选择”当前节点两种情况下子树的最大价值，真是太妙了。
正确解法:
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode() : val(0), left(nullptr), right(nullptr) {}
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
* };
*/
class Solution {
    // Best loot for one subtree under both choices for its root.
    struct NodeState{
        int selected;     // the subtree's root is robbed
        int notselected;  // the subtree's root is left alone
    };
    // Post-order: builds this node's state from its children's states.
    // An empty subtree yields {0, 0}.
    NodeState dfs(TreeNode* cur){
        NodeState res = {0, 0};
        if(cur != nullptr){
            const NodeState left = dfs(cur->left);
            const NodeState right = dfs(cur->right);
            // Robbing cur forbids robbing either child.
            res.selected = cur->val + left.notselected + right.notselected;
            // Skipping cur lets each child take its better option.
            res.notselected = max(left.selected, left.notselected)
                            + max(right.selected, right.notselected);
        }
        return res;
    }
public:
    int rob(TreeNode* root) {
        const NodeState best = dfs(root);
        return best.selected > best.notselected ? best.selected : best.notselected;
    }
};
结果: