// Problem link: http://www.lintcode.com/zh-cn/problem/interval-sum-ii/#
/*
 * Interval Sum II: answers range-sum queries and point updates over an
 * integer array using a pointer-based segment tree.
 *
 * The tree is owned by this object: it is released in the destructor and
 * deep-copied on copy construction / assignment.
 */
class Solution {
    /* One tree node covering the closed index range [start, end]. */
    class SegmentTreeNode {
    public:
        int start, end;
        // 64-bit: the sum of many int values does not fit in an int.
        long long sum;
        SegmentTreeNode *left, *right;
        SegmentTreeNode(int start, int end, long long sum)
            : start(start), end(end), sum(sum), left(NULL), right(NULL) {}
    };

public:
    /*
     * @param A: An integer array. May be empty, in which case every query
     *           returns 0 and every modify is a no-op.
     */
    Solution(std::vector<int> A) : root(NULL) {
        if (A.empty()) return;
        root = build(0, static_cast<int>(A.size()) - 1, A);
    }

    /* Deep copy: the new object gets its own independent tree. */
    Solution(const Solution& other) : root(clone(other.root)) {}

    /* Deep-copy assignment; safe against self-assignment. */
    Solution& operator=(const Solution& other) {
        if (this != &other) {
            // Clone first so that *this is left untouched if allocation throws.
            SegmentTreeNode* copy = clone(other.root);
            destroy(root);
            root = copy;
        }
        return *this;
    }

    /* Releases every node of the tree. */
    ~Solution() { destroy(root); }

    /*
     * @param start: An integer
     * @param end: An integer
     * @return: The sum from start to end (inclusive). Indices outside the
     *          array contribute nothing; an empty range (start > end) or an
     *          empty array yields 0.
     */
    long long query(int start, int end) {
        if (start > end) return 0;
        return doQuery(root, start, end);
    }

    /*
     * @param index: An integer
     * @param value: An integer
     * @return: nothing
     *
     * Sets the element at index to value. An index outside the array is
     * ignored.
     */
    void modify(int index, int value) {
        if (root == NULL || index < root->start || index > root->end) return;
        doModify(root, index, value);
    }

private:
    SegmentTreeNode* root;

    /* Midpoint used to split [start, end]; build and doModify must agree on it. */
    static int middle(int start, int end) {
        return start + (end - start) / 2;
    }

    /* Recursively builds the subtree covering v[start..end] (start <= end). */
    static SegmentTreeNode* build(int start, int end, const std::vector<int>& v) {
        if (start == end)
            return new SegmentTreeNode(start, end, v[start]);
        const int mid = middle(start, end);
        SegmentTreeNode* node = new SegmentTreeNode(start, end, 0);
        node->left = build(start, mid, v);
        node->right = build(mid + 1, end, v);
        node->sum = node->left->sum + node->right->sum;
        return node;
    }

    /* Returns a deep copy of the subtree rooted at node (NULL for NULL). */
    static SegmentTreeNode* clone(const SegmentTreeNode* node) {
        if (node == NULL) return NULL;
        SegmentTreeNode* copy = new SegmentTreeNode(node->start, node->end, node->sum);
        copy->left = clone(node->left);
        copy->right = clone(node->right);
        return copy;
    }

    /* Frees the subtree rooted at node; accepts NULL. */
    static void destroy(SegmentTreeNode* node) {
        if (node == NULL) return;
        destroy(node->left);
        destroy(node->right);
        delete node;
    }

    /* Sum of the part of [start, end] that overlaps the subtree at node. */
    static long long doQuery(const SegmentTreeNode* node, int start, int end) {
        // No tree, or no overlap with this node's range.
        if (node == NULL || start > node->end || end < node->start)
            return 0;
        // Node range lies completely inside the requested range.
        if (start <= node->start && node->end <= end)
            return node->sum;
        // Partial overlap: only possible on internal nodes, which have both children.
        return doQuery(node->left, start, end) + doQuery(node->right, start, end);
    }

    /*
     * Writes value at position index and refreshes the sums on the path back
     * to node. Precondition (checked by modify): node->start <= index <= node->end.
     */
    static void doModify(SegmentTreeNode* node, int index, int value) {
        if (node->start == node->end) {
            node->sum = value;
            return;
        }
        const int mid = middle(node->start, node->end);
        doModify(index <= mid ? node->left : node->right, index, value);
        node->sum = node->left->sum + node->right->sum;
    }
};