Leetcode Count of Range Sum ,本题如果用o(n^2)的方法,得出结果应该是相当简单的,但是有更优的算法。
本题的关键是使每次寻找满足条件的搜索代价最小,那就想到了搜索树了,如果使用平衡二叉搜索树,每次寻找的代价为o(log(n))。但是存在一个问题,如何在不遍历所有节点的情况下而保证节点的值在下次访问时是更新后的呢?
我们使用子树的树根记录整棵子树需要更新的值,并且在下次访问时,携带更新值去遍历子树的子树,而不遍历的子树我们只是更新其的累加值,这样就保证了在不遍历子树的情况下,在下次遍历时得到的值是更新后的值。
算法解析
注:本算法只实现了二叉搜索树,没有实现平衡二叉搜索树,如果使用平衡二叉搜索树算法效率应该更高。
结点插入
- 更新遍历节的值,并且重置其累加值为0
- 判断其与要插入的值的关系,继续搜索或者插入节点,并且更新没有访问的子树的树根的累加值。
结点搜索
- 如果节点为空,则返回0.
- 更新遍历的节点的值,保存其累加值后,重置其累加值
- 如果节点在bounder与upper之间,结果加1,并且搜索左、右子树,返回1 + 左子树计数 + 右子树计数。
- 如果节点值小于bounder,则更新左子树的累加值,并搜索右子树,并返回右子树的计数。
- 如果节点值大于upper,则更新右子树的累加值,并搜索左子树,并返回左子树的计数。
相关算法代码:
#include<iostream>
#include<vector>
using namespace std;
struct Node {
long long int val;
int acc;
struct Node* left;
struct Node* right;
Node(int val): val(val), acc(0), left(NULL), right(NULL){}
};
/**
* This method using a binary search tree to record the sum of [k...n] and
* k belongs to {1...n}, every time search the tree, we need to update the value
* of the node according to the node "acc" field. What's more, the "acc" of
* parent node should propagate to its children.
*/
class Solution {
public:
int countRangeSum(vector<int>& nums, int lower, int upper) {
Node* root = NULL;
int re = 0;
// Count the result of [1....i], which must include nums[i]
for (int i = 0; i < nums.size(); i++) {
// Update the tree
if (root == NULL) {
root = new Node(nums[i]);
} else {
root->acc += nums[i];
insert(root, nums[i]);
}
// Update the result.
re += search(root, 0, lower, upper);
}
return re;
}
int search(Node* root, int val, int lower, int upper) {
if (root == NULL) {
return 0;
}
// Update value of the node according to it's acc and the val
int acc = root->acc + val;
int re = 0;
root->val += acc;
root->acc = 0;
// If the node full fill the condition, search all children
if (root->val >= lower && root->val <= upper) {
re++;
re += search(root->left, acc, lower, upper);
re += search(root->right, acc, lower, upper);
} else {
// If the node not full fill condition
// Propagate the acc to its children
if (root->val < lower) {
if (root->left != NULL) {
root->left->acc += acc;
}
re += search(root->right, acc, lower, upper);
} else if (root->val > upper) {
if (root->right != NULL) {
root->right->acc += acc;
}
re += search(root->left, acc, lower, upper);
}
}
return re;
}
void insert(Node* root, int val) {
// update the nodes value
int acc = root->acc;
root->val += acc;
root->acc = 0;
// Search the place which appropriate the val
// Propagate the acc to its children
if (val < root->val) {
if (root->left != NULL) {
root->left->acc += acc;
insert(root->left, val);
} else {
root->left = new Node(val);
}
if (root->right != NULL) {
root->right->acc += acc;
}
} else {
if (root->right != NULL) {
root->right->acc += acc;
insert(root->right, val);
} else {
root->right= new Node(val);
}
if (root->left != NULL) {
root->left->acc += acc;
}
}
}
};
int main(int argc, char* argv[]) {
Solution so;
vector<int> test;
for (int i = 1; i < argc - 2; i++) {
test.push_back(atoi(argv[i]));
}
int re = so.countRangeSum(test, atoi(argv[argc - 2]), atoi(argv[argc - 1]));
cout<<"result: "<<re<<endl;
return 0;
}
测试:./a.out -2 5 -1 -2 2
输出:result: 3