#include <stdio.h>
#include <math.h>
//取中位数
int getMid(int s, int e) { return s + (e -s)/2; }
/*
st --> 线段树的指针
index --> 当前节点在线段树中的下标. 初始为 0,因为根节点的下标是0
ss & se --> 当前节点st[index]做表示的区间 [ss .... se]
qs & qe --> 要查询的区间 起始坐标 */
int getSumRecall(int *st, int ss, int se, int qs, int qe, int index)
{
// 如果查询的区间 属于当前节点表示的区间
if (qs <= ss && qe >= se)
return st[index];
// 完全不属于
if (se < qs || ss > qe)
return 0;
// 部分属于
int mid = getMid(ss, se);
return getSumRecall(st, ss, mid, qs, qe, 2*index+1) +
getSumRecall(st, mid+1, se, qs, qe, 2*index+2);
}
/*
st, index, ss 和 se 和函数getSumRecall 相同
i --> 要更新的元素下标.(原始数组中的下标)
diff --> 需要增加的值。 所有包含i的区间都会更新 */
void updateValueRecall(int *st, int ss, int se, int i, int diff, int index)
{
if (i < ss || i > se)
return;
st[index] = st[index] + diff;
if (se != ss)
{
int mid = getMid(ss, se);
updateValueRecall(st, ss, mid, i, diff, 2*index + 1);
updateValueRecall(st, mid+1, se, i, diff, 2*index + 2);
}
}
// 更新i为新的值new_val,主要是调用updateValueRecall
void updateValue(int arr[], int *st, int n, int i, int new_val)
{
if (i < 0 || i > n-1)
{
printf("Invalid Input");
return;
}
int diff = new_val - arr[i];
arr[i] = new_val;
// 更新线段树
updateValueRecall(st, 0, n-1, i, diff, 0);
}
//查询,主要是调用getSumRecall
int getSum(int *st, int n, int qs, int qe)
{
if (qs < 0 || qe > n-1 || qs > qe)
{
printf("Invalid Input");
return -1;
}
return getSumRecall(st, 0, n-1, qs, qe, 0);
}
// 递归来构建区间为 [ss..se] 的线段树 (st[si])
// si 当前节点在线段树中的下标
// return: 当前区间的总和
int constructSTRecall(int arr[], int ss, int se, int *st, int si)
{
// 如果只有一个元素,说明到达了叶子节点
if (ss == se)
{
st[si] = arr[ss];
return arr[ss];
}
// 递归的构建左右区间(子线段树)
int mid = getMid(ss, se);
st[si] = constructSTRecall(arr, ss, mid, st, si*2+1) +
constructSTRecall(arr, mid+1, se, st, si*2+2);
return st[si];
}
/*构建线段树,主要是调用递归函数 constructSTRecall 完成 */
int *constructST(int arr[], int n)
{
// 分配内存
int *st = new int[n * 2];
// 构建线段树
constructSTRecall(arr, 0, n-1, st, 0);
return st;
}
//测试
int main()
{
int arr[] = {1, 3, 5, 7, 9, 11};
int n = sizeof(arr)/sizeof(arr[0]);
// 构建线段树
int *st = constructST(arr, n);
printf("Sum of values in given range = %d\n", getSum(st, n, 1, 3));
// 更新: arr[1] = 10 and 更新相应的区间
updateValue(arr, st, n, 1, 10);
printf("Updated sum of values in given range = %d\n",
getSum(st, n, 1, 3));
return 0;
}
线段树模板(算法)
最新推荐文章于 2024-05-08 21:05:46 发布