C++线段树简单版本
具体实验细节，年后补充。
#include <cstdlib>
#include <iostream>
#include <queue>
#include <stdexcept>
#include <vector>
using namespace std;
// One node of the segment tree. It covers the inclusive index range
// [arr_begin, arr_end] of the original array and stores the sum of that range.
// Leaf nodes have arr_begin == arr_end and NULL children.
struct node
{
    double value;   // sum of arr[arr_begin..arr_end]
    int arr_begin;  // first covered index (inclusive)
    int arr_end;    // last covered index (inclusive)
    node *left;     // child covering the lower half, NULL for leaves
    node *right;    // child covering the upper half, NULL for leaves

    // Initialise every field so a freshly allocated node never carries
    // indeterminate values or dangling child pointers. (The previous
    // `typedef struct node {...};` had no typedef name, so the typedef was
    // ignored with a compiler warning, and `new node` left all fields garbage.)
    node() : value(0), arr_begin(0), arr_end(0), left(NULL), right(NULL) {}
};
/**
 * Builds the prefix-sum table of arr.
 *
 * The result has arr.size() + 1 entries: result[0] == 0 and
 * result[i + 1] == arr[0] + ... + arr[i]. The sum of arr[l..r] is therefore
 * result[r + 1] - result[l].
 *
 * @param arr  input values (taken by const reference to avoid copying them).
 * @return     the prefix-sum table described above.
 */
std::vector<double> get_sum(const std::vector<double>& arr)
{
    std::vector<double> prefix;
    prefix.reserve(arr.size() + 1);  // exact final size, so only one allocation
    double running = 0;
    prefix.push_back(running);
    for (std::vector<double>::size_type i = 0; i < arr.size(); ++i)
    {
        running += arr[i];
        prefix.push_back(running);
    }
    return prefix;
}
/**
 * Returns the sum of arr[arr_begin..arr_end] (inclusive), using a prefix-sum
 * table produced by get_sum().
 *
 * The table is taken by const reference. create() calls this once per tree
 * node, and the previous by-value parameter copied the whole table on every call.
 *
 * @param prepix_sum  prefix sums with prepix_sum[0] == 0.
 * @param arr_begin   first index of the range (inclusive).
 * @param arr_end     last index of the range (inclusive).
 *
 * Precondition (not checked): 0 <= arr_begin and arr_end + 1 < prepix_sum.size().
 */
double query(const std::vector<double>& prepix_sum,
             int arr_begin,
             int arr_end)
{
    return prepix_sum[arr_end + 1] - prepix_sum[arr_begin];
}
/**
 * Recursively builds the segment-tree node covering arr[arr_begin..arr_end]
 * (inclusive). Each node's value is taken from the prefix-sum table.
 *
 * The table is passed by const reference. The previous by-value parameter
 * copied the whole table at every recursion level, which made building the
 * tree quadratic.
 *
 * @param prepix_sum  prefix sums produced by get_sum().
 * @param arr_begin   first index covered by the new node.
 * @param arr_end     last index covered by the new node.
 * @return  a heap-allocated node that the caller owns, or NULL when the range
 *          is empty (arr_begin > arr_end). For example, arr.size() - 1 on an
 *          empty array yields arr_end == -1; the old code recursed forever there.
 */
node* create(const std::vector<double>& prepix_sum,
             int arr_begin,
             int arr_end)
{
    if (arr_begin > arr_end)
    {
        return NULL;
    }

    node *n = new node;
    n->value = query(prepix_sum, arr_begin, arr_end);
    n->arr_begin = arr_begin;
    n->arr_end = arr_end;
    if (arr_begin == arr_end)
    {
        // Leaf: covers a single array element.
        n->left = NULL;
        n->right = NULL;
    }
    else
    {
        // Split into [arr_begin, mid] and [mid + 1, arr_end]; both are non-empty.
        int mid = (arr_begin + arr_end) / 2;
        n->left = create(prepix_sum, arr_begin, mid);
        n->right = create(prepix_sum, mid + 1, arr_end);
    }
    return n;
}
/**
 * Returns the sum of arr[arr_begin..arr_end] (inclusive) using the segment
 * tree rooted at n.
 *
 * Nodes are examined breadth-first:
 *  - a node disjoint from the query range is skipped together with its whole
 *    subtree (the previous version kept descending into such nodes and so
 *    visited every node outside the range);
 *  - a node lying entirely inside the range contributes its stored sum;
 *  - a partially overlapping node is split into its children.
 *
 * @return the range sum; 0 for a NULL tree or a range that overlaps nothing.
 */
double sum_range(node *n, int arr_begin, int arr_end)
{
    double sum = 0;
    if (n == NULL)
    {
        return sum;
    }

    std::queue<node*> pending;
    pending.push(n);
    while (!pending.empty())
    {
        node *p = pending.front();
        pending.pop();

        // No overlap: no descendant can contribute either.
        if (p->arr_end < arr_begin || p->arr_begin > arr_end)
        {
            continue;
        }
        // Fully contained: use the precomputed sum without descending.
        if (arr_begin <= p->arr_begin && p->arr_end <= arr_end)
        {
            sum += p->value;
            continue;
        }
        // Partial overlap: refine through the children.
        if (p->left != NULL)
        {
            pending.push(p->left);
        }
        if (p->right != NULL)
        {
            pending.push(p->right);
        }
    }
    return sum;
}
/**
 * Sets arr[idx] to new_val in the tree rooted at n and refreshes the sums
 * stored in every ancestor of that leaf.
 *
 * The previous version put the leaf itself on the path. It assigned new_val
 * to the leaf and then also added diff to it, so updating 6 -> 10 stored 14
 * in the leaf. Here the leaf is assigned exactly once, and each ancestor is
 * then recomputed from its two children.
 *
 * @param n        root of a tree built by create().
 * @param idx      array index to overwrite.
 * @param new_val  value to store at idx.
 * @throws std::out_of_range if n is NULL or idx lies outside
 *         [n->arr_begin, n->arr_end]. The previous version dereferenced a
 *         NULL child in that case.
 */
void update(node *n, int idx, double new_val)
{
    if (n == NULL || idx < n->arr_begin || idx > n->arr_end)
    {
        throw std::out_of_range("update: index outside the tree's range");
    }

    // Walk root -> leaf, remembering the internal nodes passed through.
    std::vector<node*> ancestors;
    node *cur = n;
    while (cur->arr_begin != cur->arr_end)
    {
        ancestors.push_back(cur);
        // Pick the child whose range contains idx. Comparing against the
        // child's own bounds avoids re-deriving create()'s midpoint formula.
        cur = (idx <= cur->left->arr_end) ? cur->left : cur->right;
    }
    cur->value = new_val;

    // Rebuild sums bottom-up so every parent equals the sum of its children.
    for (std::vector<node*>::size_type i = ancestors.size(); i > 0; --i)
    {
        node *p = ancestors[i - 1];
        p->value = p->left->value + p->right->value;
    }
}
int main()
{
vector<double> arr;
arr.push_back(8);
arr.push_back(7);
arr.push_back(6);
arr.push_back(5);
arr.push_back(4);
arr.push_back(3);
vector<double> prepix_sum = get_sum(arr);
node *root;
root = create(prepix_sum, 0, arr.size()-1);
cout << sum_range(root, 0, 3) << endl;
update(root,2, 10);
cout << sum_range(root, 0, 3) << endl;
system("pause");
return 0;
}
原理：