【算法】【树状数组】 原理详解、模板与例题

树状数组要解决什么问题

m次操作,每次操作可以查询数组区间长度为n的 和/积,或者更新某位置的值。

不用树状数组的几种处理方式:

  • 每次查询直接扫描该区间取和,一次取和时间复杂度为O(n)。更新时间复杂度为O(1)
  • dp思想提前获取前缀和数组,每次取和直接取区间结尾前缀和减区间开头前缀和的值,这样时间复杂度为O(1),但是更新时需要从更新位置更新前缀和数组,时间复杂度为O(n)
  • 线段树算法,树状数组能解决的线段树都能解决,但是线段树代码量较大,容易出错。取和和更新的时间复杂度都为O(logn)

树状数组取和和更新的时间复杂度都为O(logn)

树状数组的逻辑

下图为建立树状数组的过程及查询、更新的逻辑。本文以求区间和为例。
在这里插入图片描述

tree代表树状数组,arr代表原数组

如图,树状数组每个位置存储一段区间和。如
tree[1]=arr[1]
tree[2]=arr[1]+arr[2]
tree[3]=arr[3]
tree[4]=arr[1]+arr[2]+arr[3]+arr[4]

如何查询区间和

求区间[4,7]的和,则是求出[1,7]的和-[1,3]的和

流程

树状数组的特点是:树状数组的下标为管辖区间的结尾。求和需要将自己及之前的所有区间和相加,因此需要不断找上个区间。

图中蓝线是找上个区间的含义,

  • tree[7]+tree[6]+tree[4]得到区间[1,7]的和
  • tree[3]+tree[2]得到区间[1,3]的和

两者相减即为[4,7]和

如何查找上个区间树状数组的下标

位运算lowbit

由该数值二进制计算出只包含二进制结尾1的数值。计算方式是 m&(-m)。

7(0111)->7(0111)&-7(0001)->1(0001)
6(0110)->6(0110)&-6(0010)->2(0010)
查找上个位置

m-lowbit(m)=>m-m&(-m)

7->7-1->6
6->6-2->4
4->4-4->0

累计求和

不断找上个位置,直到位置0。累计加和即为结果

如何更新

如arr[1]的值从3变成7,相当于其结果增加了4,则依次更新树状数组的节点及父节点。

tree[1]->tree[1]+4
tree[2]->tree[2]+4
tree[4]->tree[4]+4
tree[8]->tree[8]+4

更新直至超出数组范围。

初始如何建树状数组

遍历原始数组,相当于每次对树状数组的值进行更新(数组初始元素为0)

模板代码

树状数组的工具类

/**
     * 下标从0开始的树状数组
     */
    class BIT {
        // 根据原始数组生成树状数组
        int[] treeArray;
        // 原始数组
        int[] originalArray;

        /**
         * 构造方法,
         * @param originalArray 原始数组
         */
        public BIT(int[] originalArray) {
            // 初始化树状数组
            treeArray = new int[originalArray.length];
            // 复制原始数组
            this.originalArray = Arrays.copyOf(originalArray,originalArray.length);
            // 构造树状数组,遍历原始数组,对树状数组每个位置更新
            for (int i = 0; i < originalArray.length; i++) {
                add(i, originalArray[i]);
            }
        }

        /**
         * 取x+1值的二进制最后一个1
         * 这里加一是为了能够使用下标从0开始的原始数组。
         * @param x 原始数组下标
         * @return 最后一个1代表的数值
         */
        private int lowBit(int x) {
            x++;
            return x & (-x);
        }

        /**
         * 查询区间[0,endPosition]的和
         * @param endPosition 结束位置
         * @return 取和
         */
        public int query(int endPosition) {
            // 查询范围小于0直接返回0
            if (endPosition < 0) {
                return 0;
            }
            int res = 0;
            // 从前往后找,将每段区间和相加
            while (endPosition >= 0) {
                res += treeArray[endPosition];
                // 计算存储上一个区间和的元素下标
                endPosition = endPosition - lowBit(endPosition);
            }
            return res;
        }

        /**
         * 原始数组某位置增加数值,更新树状数组
         * @param position 位置
         * @param value 增加的值
         */
        private void add(int position, int value) {
            while (position < treeArray.length) {
                treeArray[position] += value;
                position = position + lowBit(position);
            }
            originalArray[position] = value;
        }

        /**
         * 原始数组某位置修改数值,更新树状数组
         * @param position 位置
         * @param value 修改的值
         */
        public void set(int position, int value) {
            int diff = value - originalArray[position];
            add(position, diff);
        }
    }

使用例子

	private BIT bit;
	// 初始化
    public NumArray(int[] nums) {
        bit = new BIT(nums);
    }
	// 更新某位置值
    public void update(int index, int val) {
        bit.set(index,val);
    }
	// 取区间和
    public int sumRange(int left, int right) {
        return bit.query(right) - bit.query(left - 1);
    }

例题


17年写的C++代码,留个纪念
POJ3468

#include<iostream>
#include<stdio.h>
#include<algorithm>
#include<cmath>
#include<set>
#include<vector>
#include<map>
#include<string>
#include<stdlib.h>
#include<limits.h>
using namespace std;
/******************************************************/
#define LL long long int
#define mem(a,b) memset(a,b,sizeof(a))
#define m ((l+r)/2)
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define L rt<<1
#define R rt<<1|1
#define N 400000+1
#define pow(a) a*a
#define INF 0x3f3f3f3f
#define max(a,b) (a>b?a:b)
#define min(a,b) (a<b?a:b)
#define lowbit(x) (x&-x)
/*********************************************************/
/*sum[x] = org[1]+...+org[x] + delta[1]*x +
delta[2]*(x-1) + delta[3]*(x-2)+...+delta[x]*1
= org[1]+...+org[x] + segma(delta[i]*(x+1-i))
= segma(org[i]) + (x+1)*segma(delta[i]) - segma(delta[i]*i),1 <= i <= x*/
/***************************************************/
/*10 5
1 2 3 4 5 6 7 8 9 10
Q 4 4
Q 1 10
Q 2 4
C 3 6 3
Q 2 4*/
LL n, q;
LL dat[N];
LL summ[N];
LL bit0[N], bit1[N];
char s[2];
void add(LL arr[],LL x, LL d){
	while (x <= n){
		arr[x] += d;
		x += lowbit(x);
	}
}
LL find(LL x,LL su[]){
	LL res = 0;
	while (x > 0){
		res += su[x];
		x -= lowbit(x);
	}
	return res;
}
int main(){
	mem(summ, 0);
	mem(bit0, 0);
	mem(bit1, 0);
	scanf("%lld%lld", &n, &q);
	for (int i = 1; i <= n; i++)scanf("%lld", &dat[i]);
	summ[0] = 0;
	for (int i = 1; i <= n; i++){
		summ[i] = summ[i - 1] + dat[i];
	}
	while (q--){
		LL a, b;
		scanf("%s%lld%lld", s, &a, &b);
		if (s[0] == 'Q'){
			LL sum = summ[b] - summ[a - 1];
			LL sum1 =(b+1)* find(b, bit0) - find(b, bit1);
			LL sum2 = a*(find(a - 1, bit0)) - find(a - 1, bit1);
			printf("%lld\n", sum + sum1 - sum2);
		}
		else{
			LL c; scanf("%lld", &c);
			add(bit0, a, c);
			add(bit0, b + 1, -c);
			add(bit1, a, c*a);
			add(bit1, b + 1, -c*(b + 1));
		}
	}
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值