树状数组的作用:
主要用于数组的单点修改&&区间求和
我们求数组的和一般是遍历,时间是O(n),每次修改一个数时间也是O(n),树状数组的作用就是优化这个过程
树状数组的思路:
假设我们要计算比每个数大或小的数有几个 (题目)
树状数组(如t[8])其实只有一个核心思想:用二进制形式存储数的和并进行运算
什么意思?
即我们每有一个数(如5),它的二进制为101,那么首先t[5]++,然后5加它本身最低位的1代表的数(1)即101(5)+001(1)=110(6),变成110(6),那么t[6]++,然后110(6)+010(2)=1000(8),t[8]++(终止位置自己设置),停了
代码实现:
int add(int x)//把包含这个数的结点都更新
{
while(x<=n)//范围
{
t[x]++;
x+=lowbit(x);
}
}
我们一般的数组是这样的(我用wps画的,有点丑别介意):
当我们用例子那样来对这个数组进行加法,那么就变成了
用二进制形式看
树状数组说每个 t[i] 存储的是从 t[1] 到 t[i] 的和,其实这种说法是不准确的,在树状数组中,每一个二进制最首位为1其他位为0的数字 i 才是从 t[1] 到 t[i] 的和,其他的数字要将它二进制除首位以外的1对应位的 t[i] 累加才是 t[1] 到 t[i] 的和(如5(101),ans+=t[5],5再减掉最低位1代表的十进制数,变成100(4),ans+= t[4] 此时4再减就变成0,结束,小于等于5的数有ans个)
代码实现:
int sum(int x)//查询1~X的和
{
int res=0;
while(x>=1)
{
res+=t[x];
x-=lowbit(x);
}
return res;
}
啊,对了,说了这么多,都没说怎么找到最低位,前面的lowbit就是自己写的找最低位1代表的数的函数
int lowbit(int x)
{
return x&-x;
}
lowbit为什么这么写?
我们都知道(我学这个之前不知道)一个数的负数等于这个数取反(1变0,0变1)加1
拿5举例:
5是101,5首先取反变成010,再加1变成011,然后与5本身相与,即为001,道理:取反后加一把后面第一个为0的数变成1,即找到原本的数第一个1,因为相反数第一个1后面全是0,相与都为0,所以得到的是最低位的1
加个例题吧
附上ac代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
ll ans[100010];
ll kid[100010]; //kid存储原数组
ll bit[1000010]; //bit存储树状数组
ll n;
void add(ll x)
{
ll i;
for(i=x;i<=n;i+=i&-i)
bit[i]++;
}
ll psum(ll x)
{
ll i,sum;
sum=0;
for(i=x;i>0;i-=i&-i)
sum+=bit[i];
return sum;
}
int main()
{
ll m,i;
long long x;
scanf("%lld",&m);
for(i=0;i<m;i++)
{
scanf("%lld",&kid[i]);
kid[i]++;
n=max(n,kid[i]);
}
for(i=0;i<m;i++)
{
ans[i]=psum(n)-psum(kid[i]); //这时ans是在输入第i个小朋友前比他高的人
add(kid[i]);
}
memset(bit,0,sizeof(bit));
for(i=m-1;i>-1;i--)
{
ans[i]+=psum(kid[i]-1); //这时ans+=所有在他后面比他矮的人
add(kid[i]);
}
x=0;
for(i=0;i<m;i++)
x+=(1+ans[i])*ans[i]/2;
printf("%lld",x);
return 0;
}
哦哦哦,对了,树状数组有个离散化处理数组的知识点
什么是数组离散化?
数组离散化就是说不关一个数组中这个数的具体大小,只管它在这个数组中的相对大小,如[0,1,1000000,2],数组离散化之后就是[1,2,4,3],即只按你是第几大排序
作用:减少树状数组的运算次数
具体代码如下:(这个就是按a数组排序1到n,排序方式按a数组排)
#include<bits/stdc++.h>
#define M 500005
using namespace std;
int a[M],d[M],t[M],n;
int lowbit(int x)
{
return x&-x;
}
int add(int x)//把包含这个数的结点都更新
{
while(x<=n)//范围
{
t[x]++;
x+=lowbit(x);
}
}
int sum(int x)//查询1~X的和
{
int res=0;
while(x>=1)
{
res+=t[x];
x-=lowbit(x);
}
return res;
}
bool cmp(int x,int y)//离散化比较函数
{
if(a[x]==a[y]) return x>y;//避免元素相同
return a[x]>a[y];//按照原序列第几大排列
}
int main()
{
long long ans=0;
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
d[i]=i;
}
sort(d+1,d+n+1,cmp);//离散化
for(int i=1;i<=n;i++)
{
add(d[i]);//把这个数放进去
ans+=sum(d[i]-1);//累加
}
printf("%d",ans);
return 0;
}