主席树,又叫可持久化线段树,一种可持久化的数据结构。
一种基本用处是查询区间中排名为k的数字;还有一种是普通的线段树操作区间修改,区间查询,或者区间历史查询;另外还有一种用处是求区间有多少个不同的数字
它们的建树方式有所不同,
第一种的建树方式是先把所有数字去重+离散化,然后得到不重复数字的个数n,然后[l,r]这个节点里存的是当前前缀中[a[l] , a[r]]有多少个数字,这里跟普通的线段树存的东西不太一样,有点类似于权值线段树,不过是离散之后的权值线段树。
第二种的建树方式就是和线段树一样,然后每次修改新增一个 根,把需要修改的节点新建出来(利用lazy可以使得每次新建的节点不超过log n 个)。然后就可以回到过去查询历史的区间信息。有个地方要注意一下,这里的lazy实际上没有pushdown操作,而是标记了区间被加了多少,然后在查询的时候,把这个区间的值层层累加传递下去,直到目标区间,这样减少了pushdown中新增的lazy节点,节省了大量的空间。
第三种的建树方式和第二种类似,如果可以离线的话,我们知道可以让r排序,然后从小到大枚举r。假设当前枚举到r,a[r]前一次出现的位置记为p,p位置减去1,r位置加上1,然后查询[L,R]的和就是这一段有多少个不同数字出现。但是如果是在线的话,当我们询问到r的时候,会把前面某一个点减去,假设下次询问r-x时,就有可能这个点已经被减去导致结果变小。 所以一棵线段树是不够的,我们考虑建n棵线段树,但是每次插入一个 r ,实际上只会新增两条链,(如果该数字没有出现过则是一条)。
然后每次查询[l,r],就是以第r棵线段树的根的[L,R]这段的和。
贴一下几种用法的模板题
第一种用法的模板题,hdu2665
#include<cmath>
#include<algorithm>
#include<cstring>
#include<string>
#include<set>
#include<map>
#include<time.h>
#include<cstdio>
#include<vector>
#include<list>
#include<stack>
#include<queue>
#include<iostream>
#include<stdlib.h>
using namespace std;
#define LONG long long
const int INF=0x3f3f3f3f;
const LONG MOD=1e9+ 7;
const double PI=acos(-1.0);
#define clrI(x) memset(x,-1,sizeof(x))
#define clr0(x) memset(x,0,sizeof x)
#define clr1(x) memset(x,INF,sizeof x)
#define clr2(x) memset(x,-INF,sizeof x)
#define EPS 1e-10
#define lson l , mid , rt<< 1
#define rson mid + 1 ,r , (rt<<1)+1
#define root 1, n , 1
const int MAXN = 1e5 ;
struct Tree{
int l, r ;
int val ;
}tree[MAXN *30+ 30];
int N , n ;
int a[100100] ;
int Root[100100] ;
int tot = 0;
int num[100100] ;
void Push_up(int rt )
{
tree[rt].val = tree[tree[rt].l].val + tree[tree[rt].r].val ;
}
int Build(int l, int r)
{
int rt = tot ++ ;
if(l == r)
{
tree[rt].val = 0 ;
return rt ;
}
int mid = (l + r) / 2;
tree[rt].l = Build(l , mid ) ;
tree[rt].r = Build(mid + 1, r) ;
Push_up(rt) ;
return rt ;
}
void Hash()
{
sort(a + 1, a +N+1 ) ;
n = unique(a +1 , a + N +1 ) - a - 1;
}
int Update(int l ,int r , int rt , int x)
{
tot ++ ;
int now = tot ;
tree[now] = tree[rt] ;
if(l ==r )
{
tree[now].val ++ ;
return now;
}
int mid = ( l + r ) / 2;
if( x <= a[mid])
tree[now].l = Update(l , mid , tree[now].l ,x) ;
else
tree[now].r = Update(mid + 1, r , tree[now].r , x) ;
Push_up( now ) ;
return now ;
}
int Que( int l , int r , int R_rt ,int L_rt , int k)
{
if( ( l == r) )
return a[l];
int mid = (l +r ) / 2;
if(k > tree[tree[R_rt].l].val - tree[tree[L_rt].l].val )
return Que(mid+1 ,r , tree[R_rt].r,tree[L_rt].r , k - tree[tree[R_rt].l].val + tree[tree[L_rt].l].val ) ;
else
return Que(l , mid ,tree[R_rt].l, tree[L_rt].l , k ) ;
}
int main()
{
int T ;
cin >> T ;
int m ;
while(T --)
{
tot = 0 ;
scanf("%d%d",&N ,&m);
for(int i =1 ; i <= N ;++ i)
scanf("%d",&a[i]),num[i] = a[i] ;
Hash() ;
int now = 0 ;
Root[0] = Build(1 , n );
int p ;
for(int i = 1;i<= N ; ++ i)
Root[i] = Update(1 , n , Root[i-1] , num[i] ) ;
int K ;
int l , r ;
while(m --)
{
scanf("%d%d%d",&l,&r, &K) ;
cout<<Que(1,n,Root[r],Root[l-1] , K) <<endl ;;
}
}
}
第二种用法区间更新,区间查询,区间历史查询,hdu4348
#include<cmath>
#include<algorithm>
#include<cstring>
#include<string>
#include<set>
#include<map>
#include<time.h>
#include<cstdio>
#include<vector>
#include<list>
#include<stack>
#include<queue>
#include<iostream>
#include<stdlib.h>
using namespace std;
#define LONG long long
const int INF=0x3f3f3f3f;
const LONG MOD=1e9+ 7;
const double PI=acos(-1.0);
#define clrI(x) memset(x,-1,sizeof(x))
#define clr0(x) memset(x,0,sizeof x)
#define clr1(x) memset(x,INF,sizeof x)
#define clr2(x) memset(x,-INF,sizeof x)
#define EPS 1e-10
#define lson l , mid , rt<< 1
#define rson mid + 1 ,r , (rt<<1)+1
#define root 1, m , 1
const int MAXN = 100000*30+55;
struct Tree{
LONG Sum ;
int l ,r ;
LONG lazy ;
}tree[MAXN ];
int tot = 0;
int Root[101000] ;
int now =0 ;
int Build(int l ,int r )
{
tot ++ ;
int rt = tot ;
tree[rt].lazy = 0 ;
if ( l == r)
{
scanf("%lld" , &tree[rt].Sum) ;
return rt ;
}
int mid = (l+r) / 2;
tree[rt].l = Build(l,mid ) ;
tree[rt].r = Build(mid+1 , r );
tree[rt].Sum = tree[tree[rt].l].Sum + tree[tree[rt].r].Sum ;
return rt ;
}
int Update(int L, int R , int l ,int r ,int rt , LONG val )
{
tot ++ ;
int now = tot ;
tree[now] = tree[rt] ;
if( L <= l && r <= R)
{
tree[ now ].Sum += ( r - l + 1) * val ;
tree[ now ].lazy += val ;
return now ;
}
int mid = ( l +r )/ 2;
if( L <= mid)
tree[now].l = Update(L, R , l,mid, tree[now].l , val) ;
if(R > mid )
tree[now].r = Update(L ,R , mid + 1 ,r , tree[now].r ,val ) ;
tree[now].Sum = tree[tree[now].l ].Sum + tree[tree[now].r].Sum + ( r - l + 1) * tree[now].lazy ;
return now ;
}
LONG Query(int L ,int R , int l , int r ,int rt ,LONG val )
{
if(L <= l && r <= R)
return tree[rt].Sum + ( r - l + 1) *val ;
int mid = ( l +r )/ 2;
LONG ans = 0;
if(L <= mid)
ans += Query(L , R , l , mid , tree[rt].l , val + tree[rt].lazy) ;
if(R > mid )
ans += Query(L ,R , mid + 1 ,r ,tree[rt].r , val + tree[rt].lazy) ;
return ans ;
}
int main()
{
int n , m;
while(~scanf("%d%d",&n,&m))
{
tot= 0 ;
now= 0;
Root[now] = 1;
char op[5] ;
int l ,r ;
LONG val ;
int t ;
Root[now] = Build(1 , n ) ;
while(m -- )
{
scanf("%s" , op);
if(op[0] =='C')
{
scanf("%d%d%lld",&l,&r ,&val) ;
Root[++now] = Update(l ,r , 1 , n ,Root[now-1] , val) ;
}
else if(op[0] == 'Q')
{
scanf("%d%d",&l, &r ) ;
printf("%lld\n",Query(l ,r ,1,n,Root[now] , 0) ) ;
}
else if(op[0] == 'H')
{
scanf("%d%d%d",&l ,&r ,&t) ;
printf("%lld\n",Query(l , r ,1 , n , Root[t] , 0) ) ;
}
else
scanf("%d",&now ) ;
}
}
}
求区间不同数字个数的用法spoj DQUERY
#include<bits/stdc++.h>
using namespace std;
#define LONG long long
const int INF=0x3f3f3f3f;
const LONG MOD=1e9+ 7;
const double PI=acos(-1.0);
#define clrI(x) memset(x,-1,sizeof(x))
#define clr0(x) memset(x,0,sizeof x)
#define clr1(x) memset(x,INF,sizeof x)
#define clr2(x) memset(x,-INF,sizeof x)
#define EPS 1e-10
const int MAXN = 3e4+10 ;
struct Tree
{
int l ,r ;
int sum ;
}tree[MAXN*30];
int Root[30100] ;
int pos[1001000] ;
int tot = 0;
int Build(int l ,int r)
{
tot ++ ;
int now = tot ;
tree[now].sum = 0 ;
if( l == r)return now ;
int mid = (l + r ) / 2;
tree[now].l = Build( l , mid ) ;
tree[now].r = Build( mid + 1, r ) ;
return now ;
}
int Update(int p, int l ,int r ,int rt , int val)
{
tot ++ ;
int now = tot ;
tree[now] = tree[rt] ;
if( r == l && p == l)
{
tree[now].sum += val ;
return now ;
}
int mid = ( l + r) / 2 ;
if( p <= mid)
tree[now].l = Update(p , l , mid , tree[now].l , val ) ;
else
tree[now].r = Update( p , mid + 1, r , tree[now].r , val) ;
tree[now].sum = tree[tree[now].l ].sum + tree[tree[now].r].sum ;
return now ;
}
int Que(int L ,int R , int l ,int r , int R_rt )
{
if( L <= l && r <= R )
return tree[R_rt].sum ;
int mid = ( l + r ) / 2;
int res = 0 ;
if(L <= mid )
res += Que( L , R , l ,mid , tree[R_rt].l ) ;
if(R > mid)
res += Que(L , R , mid + 1, r , tree[R_rt].r ) ;
return res ;
}
int main()
{
int n ;
while(cin >> n)
{
clr0(pos) ;
int tmp ,p , q ;
Root[0 ] = Build(1 , n ) ;
for(int i = 1;i <= n ;++ i)
{
scanf("%d",&tmp) ;
p = pos[tmp] ;
Root[i] = Update( i , 1 , n , Root[i-1],1 ) ;
if(p != 0)
Root[i] = Update(p , 1 , n , Root[i],-1 ) ;
pos[tmp] = i ;
}
cin >> q ;
int l ,r ;
int res2 ,res1 ;
while(q --)
{
scanf("%d%d",&l,&r);
res1 = Que( l , r , 1 , n , Root[r] ) ;
printf("%d\n",res1 ) ;
}
}
}