给定一个数列:a1 , a2 , a3 , ····,an。树状数组可以快速的完成下述操作:
- 给定i,计算a1+a2+ ···· + ai的和,即计算数列前i项的和,同样的这也可以转变为求任意给定的区间的和,如求[s,t]内的数列和,getSum(t) - getSum(s)。
- 给定i和x,执行ai += x,也就是update(i),注意由于树状数组是区间性的,所以i之后的元素也会更新。
令这棵树的结点编号为C1,C2…Cn。令每个结点的值为这棵树的值的总和,那么容易发现:
C1 = A1
C2 = A1 + A2
C3 = A3
C4 = A1 + A2 + A3 + A4
C5 = A5
C6 = A5 + A6
C7 = A7
C8 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8
就是说,对输入的数据(存储在数组A中)进行了数据处理,将其存到了数组C中,而在二进制下就会发现数组C的下标对应的元素与A有关系。
设节点编号为x,那么这个节点在数组C中管辖的区间为2^k(其中k为x二进制末尾0的个数)个元素。因为这个区间最后一个元素必然为Ax,
所以很明显:Cn = A(n – 2^k + 1) + … + An
也就是说求Ci,就是不断的把数组A在i位置的元素加到结果中,并从i中减去i的二进制最低非0位对应的幂,直到i变成0为止。i的二进制的最后一个1可以通过i&-i得到,也就是i -= i & -i 。
import java.util.Arrays;
import java.util.Scanner;
public class Main {
public static int n;
public static int[] nums = null;
/*使第i项的值增加v需要从i开始,不断的把当前位置i的值增加v,并把i的二进制最低非0位对应的幂加到i上。*/
public static void update(int i,int v){
while(i<=n){
nums[i] += v;
i += i & -i;
}
}
public static int getSum(int i){
int sum = 0;
while(i>0){
sum += nums[i];
i -= i & -i;
}
return sum;
}
public static void main(String[] args) {
// TODO Auto-generated method stub
Scanner in = new Scanner(System.in);
while(in.hasNext()){
n = in.nextInt();
nums = new int[n+10];
//由于累加,初始化为0
Arrays.fill(nums, 0);
int t = 0;
//数组元素从1开始输入。
for(int i=1;i<=n;i++){
t = in.nextInt();
update(i, t);
}
for(int i=1;i<=n;i++) System.out.println(i+","+getSum(i));
}
}
}
树状数组通过 lowbit(k) = k & -k 找到相应的位,然后根据是update还是getSum执行加或减。
应用,树状数组求逆序数:
import java.util.Arrays;
import java.util.Scanner;
public class Main {
public static int n;
public static int[] nums = null;
public static void update(int i,int v){
while(i<=n){
nums[i] += v;
i += i & -i;
}
}
public static int getSum(int i){
int sum = 0;
while(i>0){
sum += nums[i];
i -= i & -i;
}
return sum;
}
public static void main(String[] args) {
// TODO Auto-generated method stub
Scanner in = new Scanner(System.in);
while(in.hasNext()){
n = in.nextInt();
nums = new int[n+10];
Arrays.fill(nums, 0);
int t = 0;
long ans = 0L;
for(int i=0;i<n;i++){
t = in.nextInt();
ans += i - getSum(t);
update(t, 1);
}
//for(int i=1;i<=n;i++) System.out.println(i+","+getSum(i));
System.out.println(ans);
}
}
}