树状数组(Fenwick tree,又名binary indexed tree),是一种很实用的数据结构。它通过用节点i,记录数组下标在[ i –2^k + 1, i]这段区间的所有数的信息(其中,k为i的二进制表示中末尾0的个数,设lowbit(i) = 2^k),实现在O(lg n) 时间内对数组数据的查找和更新。
树状数组的传统解释图,不能很直观的看出其所能进行的更新和查询操作。其最主要的操作函数lowbit(k)与数的二进制表示相关,本质上仍是一种二分。因而可以通过二叉树,对其进行分析。事实上,从二叉树图,我们对它所能进行的操作和不能进行的操作一目了然。
和前面提到的点树类似,先画一棵二叉树,然后对节点中序遍历(点树是采用广度优先),每个节点仍然只记录左子树信息,见图:
由于采用的是中序遍历,从节点1到节点k时,刚好有k个叶子被统计。
可以证明:
叶子k,一定在节点k的左子树下。
以节点k为根的树,其左子树共有叶子lowbit(k)
节点k的父节点是:k + lowbit(k) 或 k - lowbit(k)
节点k + lowbit(k) 是节点k的最近父节点,且节点k在它的左子树下。
节点k - lowbit(k) 是节点k的最近父节点,且节点k在它的右子树下。
节点k,统计的叶子范围为:(k - lowbit(k), k]。
节点k的左孩子是:k - lowbit(k) / 2
下面分析树状数组两面主要应用:
1 更新数据x,进行区间查询。
2 更新区间,查询某个数。
由于,树状数组只统计了左子树的信息,因而只能查询更新区间[1, x]。只在在满足[x,y]的信息可以由[1,x-1]和[1,y]的信息推导出时,才能进行区间[x,y]的查询更新。这也是树状数组不能用于任意区间求最值的根本原因。
先定义两个集合:
up_right(k) : 节点k所有的父节点,且节点k在它们的左子树下。
up_left(k) : 节点k所有的父节点,且节点k在它们的右子树下。
1 更新数据x,查询区间[1,y]。
显然,更新叶子x,要找出叶子x在哪些节点的左子树下。因而节点k、所有的up_right(k)
都要更新。
查询[1, y],实际上就是把该区间拆分成一系列小区间,并找出统计这些区间的节点。可以通过找出y在哪些节点的右子树下,这些节点恰好不重复的统计了区间[1, y-1]。因而要访问节点y、所有的up_left(y)。
2 更新区间[1,y],查询数据x
这和前面的操作恰好相反。与前面的最大不同之处在于:节点保存的不再是其叶子总个数这些信息,而是该区间的所有叶子都改变了多少。也就是说:每个叶子的信息,分散到了所有对它统计的节点上。因此操作和前面相似:
更新[1,y]时,更新节点y、所有up_left(y)。
查询x时, 访问x、所有up_right(x)。
前面的树状数组,只对左子树信息进行统计,如果从后往前读数据初始化树状数组,则变成只对右子树信息进行统计,这时更新和查询操作,刚好和前面的相反。
一般情况下,树状数组比点树省空间,对区间[1, M]只要M+1空间,查询更新时定位节点比较快,定位父节点和左右孩子相对麻烦点(不过,一般也不用到。从上往下查找,可参考下面代码中的erease_nth函数(删除第n小的数))。
下面是使用树状数组的实现代码(求逆序数和模拟约瑟夫环问题):
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
![ExpandedBlockStart.gif](https://images.cnblogs.com/OutliningIndicators/ExpandedBlockStart.gif)
#include < cstdio >
#include < cstring >
#include < cassert >
template < int N > struct Round2k
{ enum { down = Round2k < N / 2u > ::down * 2 }; };
template <> struct Round2k < 1 > { enum { down = 1 }; };
template < int Total, typename T = int > // 区间[1, Total]
class BIT {
enum { Min2k = Round2k < Total > ::down};
T info[Total + 1 ];
T sz; // 可以用info[0]储存总大小
public :
BIT() { clear(); }
void clear() { memset( this , 0 , sizeof ( * this ));}
int size() { return sz; }
int lowbit( int idx) { return idx & - idx;}
// 寻找最近的父节点,left_up/right_up 分别使得idx在其右/左子树下
void left_up( int & idx) { idx -= lowbit(idx); }
void right_up( int & idx) { idx += lowbit(idx); }
void update( int idx , const int val = 1 ) { // 叶子idx 改变val个
assert(idx > 0 );
sz += val;
for (; idx <= Total; right_up(idx)) info[idx] += val;
}
void init( int arr[], int n) { // arr[i]为叶子i+1的个数
assert(n <= Total);
sz = n;
// for (int i = 0; i < n; ) {
// info[i + 1] = arr[i];
// if (++i >= n) break;
// info[i + 1] = arr[i];
// ++i;
// for (int j = 1; j < lowbit(i); j *= 2u) info[i] += info[i - j];
// }
for ( int i = 0 ; i < n; ) {
info[i + 1 ] = arr[i];
if ( ++ i >= n) break ;
int sum = arr[i];
int pr = ++ i;
left_up(pr);
for ( int j = i - 1 ; j > pr; left_up(j)) sum += info[j];
info[i] = sum;
}
}
int count( int idx) { // [1,idx] - [1, idx-1]
assert(idx > 0 );
int sum = info[idx];
// int pr = idx; // int pr = idx - lowbit(idx);
// left_up(pr);
// for (--idx; idx > pr; left_up(idx)) sum -= info[idx]; //
// return sum;
for ( int j = 1 ; j < lowbit(idx); j *= 2u ) sum -= info[idx - j];
return sum;
}
int lteq( int idx) { // 小等于
assert(idx >= 1 && idx <= Total);
int sum = 0 ;
for (; idx > 0 ; left_up(idx)) sum += info[idx];
return sum;
}
int gt( int idx) { return sz - lteq(idx); } // 大于
int operator []( int n) { return erase_nth(n, 0 ); } // 第n小
int erase_nth( int n, const bool erase_flag = true ) // 删除第n小的数
{
assert(n >= 1 && n <= sz);
sz -= erase_flag;
int idx = Min2k; // 从上往下搜索,先定位根节点
for ( int k = idx / 2u ; k > 0 ; k /= 2u ) {
int t = info[idx];
if (n <= info[idx]) { info[idx] -= erase_flag; idx -= k;} // 进入左子树
else {
n -= t;
if (Total != Min2k && Total != Min2k - 1 ) // 若不是完全二叉树
while (idx + k > Total) k /= 2u ; // 则必须计算右孩子的编号
idx += k; // 进入右子树
}
}
assert(idx % 2u ); // 最底层节点m一定是奇数,有两个叶子m,m+1
if (n > info[idx]) return idx + 1 ; // 节点m+1前面已经更新过
info[idx] -= erase_flag;
return idx;
}
void show()
{
for ( int i = 1 ; i <= Total; ++ i)
if (count(i)) printf( " %2d " , i);
printf( " \n " );
}
};
void ring() // 约瑟夫环
{
const int N = 17 ; // N个人编号:1,2, ... N
const int M = 7 ; // 报数:1到M,报到M的出列
printf( " N: %d M: %d\n " , N, M);
BIT < N > pt;
// for (int i = 0; i < N; ++i) pt.update(i + 1);
int arr[N];
for ( int i = 0 ; i < N; ++ i) arr[i] = 1 ;
pt.init(arr, N);
for ( int j = N, k = 0 ; j >= 1 ; -- j) {
k = (k + M - 1 ) % j;
int t = pt.erase_nth(k + 1 );
printf( " turn: %2d out: %2d rest: " , N - j, t);
pt.show();
}
printf( " \n\n " );
}
int ra( int arr[], int len) // 求逆序数-直接搜索
{
int sum = 0 ;
for ( int i = 0 ; i < len - 1 ; ++ i)
for ( int j = i + 1 ; j < len; ++ j)
if (arr[i] > arr[j]) ++ sum;
return sum;
}
template < int N >
int rb( int arr[], int len) // 求逆序数-使用树状数组
{
BIT < N > pt;
int sum = 0 ;
for ( int i = 0 ; i < len; ++ i) {
pt.update(arr[i] + 1 );
sum += pt.gt(arr[i] + 1 );
}
return sum;
}
int main()
{
int arr[] = { 4 , 3 , 2 , 1 , 0 , 5 , 1 , 3 , 0 , 2 };
const int N = sizeof (arr) / sizeof (arr[ 0 ]);
printf( " %d %d\n\n " , ra(arr, N), rb < 6 > (arr, N));
ring();
}