数状数组是一种很有用的数据结构,模板代码非常简单,只要知道如何运用,就可以轻松的掌握了。
下面给出数状数组的模型:
它的原理是利用了计算机的位运算的小技巧和前缀和的概念,也就是说,将我们有的数据作为叶子节点,同时用一个前缀数组来维护前缀和,这样,在我们求某一区间的和时,会非常的方便。
下面给出代码的模板:
int lowbit(int x) //==2^k k为从最低位到最高位,连续0的长度
{
return x&(-x);
}
int getsum(int x)//求前x项的和
void update(int x,int value)
{
for(int i=x;i<=n;i+=lowbit(i))
{
c[i]+=value;
}
}
void update(int x,int value)//对叶子数组的x进行修改后,需要更新前缀数组
{
for(int i=x;i<=n;i+=lowbit(i))
{
c[i]+=value;
}
}
下面是一道利用数状数组的例题:
Description
Ultra-QuickSort produces the output
Your task is to determine how many swap operations Ultra-QuickSort needs to perform in order to sort a given input sequence.
Input
Output
Sample Input
5 9 1 0 5 4 3 1 2 3 0
Sample Output
6 0
题意是让你求冒泡排序的交换次数,但是冒泡排序复杂度为O(n^2)这样明显超时了,所以需要对算法进行优化。求交换次数,也就等价于求这个数组中有多少个逆序对(也就是a[i+1]>a[i],<a[i],a[i+1]>),我们首先想到的是遍历每个数,求它之前有多少个数比它大,在求和。但是这样还是超时。
所以,该如何优化呢?
我们可以这样想,第i个数,前面有多少个比他大,不就等价于i-(前面比它小的个数)
好的,现在思路确定了,那么如何实现呢?
我们又发现,我们只需要把原数组放到数状数组里面,对每个数进行查询getsum(a[i]),返回的值,就是前i个数中比a[i]小的数的个数。(注意,我们之所以可以这样做,是因为我们选择了边查询边建树,你可以看到,数据 9 1 0 5 4,我们要查的第一个是9,getsum(9),但此时树全部为0,返回0,这与我们想要的结果一致,这时,我们在来更新树,将9放入树中,下次我们查询getsum(1),在1之前,比它小的数没有,代表着比它大的数有1个(这里就是9)......依次这样推理下去就好了)
前面的内容可能有点难以理解,需要一点时间思考,主要是要理解前i项和维护的是什么。
当然,这个题,还可以进行优化,那就是比如一个数a=99999999999999999,那我要去查getsum(99999999999999)
吗?当然没必要!
所以我们需要离散化一下,9 1 0 5 4
我们可以离散化为 5 2 1 4 3 这样,最高位从9变成了5,空间就可以节省下来了。
下面是离散化的解法代码:
#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <set>
#include <map>
#include <string>
#include <math.h>
#include <stdlib.h>
#include <time.h>
typedef long long ll;
using namespace std;
const int maxn = 500010 ;
int a[maxn],c[maxn];
int n;
struct P
{
int value;
int id;
};
int lowbit(int x) //==2^k k为从最低位到最高位,连续0的长度
{
return x&(-x);
}
int getsum(int x)
{
int ans=0;
for(int i=x;i>0;i-=lowbit(i))//减去间隔
{
ans+=c[i];
}
return ans;
}
void update(int x,int value)
{
for(int i=x;i<=n;i+=lowbit(i))
{
c[i]+=value;
}
}
bool cmp(P a,P b)
{
return a.value<b.value;
}
int main()
{
struct P t[maxn];
while(scanf("%d",&n)&&n)
{
for(int i=1;i<=n;i++)
{
scanf("%d",&t[i].value);
t[i].id=i;
}
sort(t+1,t+n+1,cmp);//离散化
for(int i=1;i<=n;i++)
{
a[t[i].id]=i;
}
memset(c,0,sizeof(c));
ll ans=0;
for(int i=1;i<=n;i++)
{
update(a[i],1);
ans+=(i-getsum(a[i]));
}
cout<<ans<<endl;
}
return 0;
}