1. 一维树状数组
提出问题:求前缀和,比如给出一个数列1 5 4 2 7..
, 求出前k个元素的累加和,算法如下:
int sum=0;
for(int i=0;i<k;i++){
sum+=arr[i];
}
时间复杂度:O(n)
当修改某个元素arr[k], sum[k] sum[k+1]…这些前缀和的计算全部都要重新计算, 因此引入树状数组:
树状数组引入了分级管理制度且设置了一个管理小组,管理小组中的每个成员都管理一个或多个连续的元素。例如,在数列中有9个元素,分别用a [1], a [2], …, a [9]存储,还设置了一个管理小组c []
c [1]:存储a [1]的值
c [2]:存储c [1]、a [2]的和值,相当于存储a [1]、a [2]的和值
c [3]:存储a [3]的值
c [4]:存储c [2]、c [3]、a [4]的和值,相当于存储a [1]、a [2]、a [3]、a [4]的和值
c [5]:存储a [5]的值
c [6]:存储c [5]、a [6]的和值,相当于存储a [5]、a [6]的和值
c [7]:存储a [7]的值
c [8]:存储c [4]、c [6]、c [7]、a [8]的和值,相当于存储a [1]~a [8]的和值
c [9]:存储a [9]的值
1. 查询前缀和
假设现在查询前缀和sum[k], 则需要c[k]加上c[k]左侧所有子树的根即可
- sum[4]=c[4]
- sum[5]=c[4]+c[5]
- sum[9]=c[8]+c[9]
2. 点更新
点更新指修改一个元素的值,例如对a [5]加上一个数y ,则需要更新该元素的所有祖先节点,即c [5]、c [6]、c [8],令这些节点都加上y 即可,对其他节点都不需要修改, 原因在于只有祖先节点前缀和中包含了当前数据节点
3. 区间长度
从前面的树状数组中可以看到,c[1]指包含类一个数据,c[2]包含了2个数据,c[4]包含了4个数据…c[9]包含了一个数据,这些不同的c[i]包含的数据的个数不一样,数据的个数如何确定呢?
1=001---->末尾0个0—>
2
0
=
1
2^0=1
20=1
2=010---->末尾1个0—>
2
1
=
2
2^1=2
21=2
3=011---->末尾0个0—>
2
0
=
1
2^0=1
20=1
4=100---->末尾2个0—>
2
2
=
4
2^2=4
22=4
…
因此包含的数据个数(区间长度)为 2 二 进 制 末 尾 0 的 个 数 2^{二进制末尾0的个数} 2二进制末尾0的个数
如果方便地求出这个长度,不断右移求出末尾0的个数是一种做法,比较简单方便地一种做法如下:
10100最低位的1及其后面的0构成的数值100, 如何得到110? 做法是(-x)&x
public int lowbit(int i){
return (-i)&i;
}
4. 前驱和后继
直接前驱:c [i ]的直接前驱为c [i -lowbit(i )],即c [i ]左侧紧邻的子树的根
直接后继:c [i ]的直接后继为c [i +lowbit(i )],即c [i ]的父节点
前驱:c [i ]的直接前驱、其直接前驱的直接前驱等,即c [i ]左侧所有子树的根
后继:c [i ]的直接后继,其直接后继的直接后继等,即c [i ]的所有祖先
5. 前缀和查询算法
前i 个元素的前缀和sum[i ]等于c [i ]加上c [i ]的前驱
public int sum(int i){
int sum=0;
while(i>0){
sum+=c[i];
i-=lowbit(i);//直接前驱i-=lowbit
}
return sum;
}
6. 点更新算法
a[i]=a[i]+delta, 只需要更新c[i]和c[i]的祖先节点
public void add(int i,int delta){
while(i<=n){
c[i]+=delta;
i+=lowbit(i);
}
}
树状数组的下标从1开始,不可以从0开始,因为lowbit(0)=0时会出现死循环
6. 区间和查询算法
若求区间和值a [i ]+a [i +1]+…+a [j ],则求解前j 个元素的和值减去前i -1个元素的和值即可,即sum[j ]-sum[i -1]
public int sum(int i,int j){
return sum(j)-sum(i-1);
}
7. 时间复杂度分析
从下标1开始考虑,n表示数组最后一个元素的下标
查询复杂度:O(logn)
更新复杂度:O(logn)
完整代码:
package algorithm;
import java.util.Arrays;
public class TreeArray {
int[] c;
int n;
public TreeArray(int[] arr) {
this.n=arr.length;
c=new int[n+1];
}
/*
* 计算区间长度
*/
public int lowbit(int i){
return (-i)&i;
}
/*
* 前缀和查询
*/
public int sum(int i){
int sum=0;
while(i>0){
sum+=c[i];
i-=lowbit(i);//直接前驱i-=lowbit
}
return sum;
}
/*
* 区间和查询
*/
public int sum(int i,int j){
return sum(j)-sum(i-1);
}
/*
* 点更新
*/
public void add(int i,int delta){
while(i<=n){
c[i]+=delta;
i+=lowbit(i);
}
}
public static void main(String[] args) {
int[] arr= {1,4,6,2,8,10,9,7};
TreeArray tArray=new TreeArray(arr);
for(int i=0;i<arr.length;i++) {
tArray.add(i+1, arr[i]);
}
System.out.println("c数组: "+Arrays.toString(tArray.c));
for(int i=1;i<=arr.length;i++) {
System.out.println("sum["+i+"]="+tArray.sum(i));
}
System.out.println("sum(2,6)="+tArray.sum(2, 6));//4,6,2,8,10
}
}
2. 多维树状数组
1. 查询前缀和
二维数组的前缀和实际上是从数组左上角到当前位置(x , y )矩阵的区间和
public int sum(int x,int y){//求左上角(1,1)位置到位置(x,y)位置的区间和
int sum=0;
for(int i=x;i>0;i-=lowbit(i)){
for(int j=y;j>0;j-=lowbit(j)){
sum+=c[i][j];
}
}
return sum;
}
2. 点更新
arr[i][j]--->arr[i][j]+delta
public void add(int x,int y,int delta){//arr[x][y]加上delta
for(int i=x;i<=n;i+=lowbit(i)){
for(int j=y;j<=n;j+=lowbit(j)){
c[i][j]+=delta;
}
}
}
3. 查询区间和值
求左上角(x1,y1)
到右下角(x2,y2)
子矩阵的区间和
public int sum(int x1,int y1,int x2,int y2){
return sum(x2,y2)-sum(x1-1,y2)-sum(x2,y1-1)+sum(x1-1,y1-1);
}
完整代码:
package algorithm;
import java.util.Arrays;
public class TreeArray {
int[][] c;
int m;
int n;
public TreeArray(int[][] arr) {
this.m=arr.length;
this.n=arr[0].length;
c=new int[m+1][n+1];
}
/*
* 计算区间长度
*/
public int lowbit(int i){
return (-i)&i;
}
/*
* 前缀和查询
*/
public int sum(int x,int y){//求左上角(1,1)位置到位置(x,y)位置的区间和
int sum=0;
for(int i=x;i>0;i-=lowbit(i)){
for(int j=y;j>0;j-=lowbit(j)){
sum+=c[i][j];
}
}
return sum;
}
/*
* 区间和查询
*/
public int sum(int x1,int y1,int x2,int y2){
return sum(x2,y2)-sum(x1-1,y2)-sum(x2,y1-1)+sum(x1-1,y1-1);
}
/*
* 点更新
*/
public void add(int x,int y,int delta){//arr[x][y]加上delta
for(int i=x;i<=n;i+=lowbit(i)){
for(int j=y;j<=n;j+=lowbit(j)){
c[i][j]+=delta;
}
}
}
public static void main(String[] args) {
int[][] arr= {
{1,2,3,4},
{5,6,7,8},
{9,10,11,12},
{13,14,15,16}
};
TreeArray tArray=new TreeArray(arr);
for(int i=0;i<arr.length;i++) {
for(int j=0;j<arr[0].length;j++) {
tArray.add(i+1,j+1, arr[i][j]);
}
}
System.out.println("c数组: "+Arrays.toString(tArray.c));
for(int i=1;i<=arr.length;i++) {
for(int j=1;j<=arr[0].length;j++) {
System.out.println("sum["+i+","+j+"]="+tArray.sum(i,j));
}
}
System.out.println("sum(1,2,3,4)="+tArray.sum(1,2, 3, 4));//(1,2)[2]->(3,4)[12]
}
}
3. 树状数组和普通数组的比较
树状数组 | 普通数组 | |
---|---|---|
前缀和查询 | O(logn) | O(n) |
区间和查询 | O(logn) | O(n) |
点更新 | O(1) | O(logn) |
点查询 | O(1) | O(logn) |
参考:《算法训练营》高级篇