首先说一下我对主席树的理解:
其实主席树就是一棵可持久化权值线段树。权值线段树就是意思是将值作为线段树的L,R下标,而Sum属性代表的是当前M下标的值有多少个。
可持久化,其实就是保存历史版本的意思。相当于每做一次更新操作,把原来的版本复制一份,再修改一些节点。
但是问题来了直接复制肯定大量空间浪费,就相当于开了若干棵线段树,直接爆炸。所以我们要尽可能使用之前的信息。只要之前的节点不变的,就将当前版本的节点直接连到之前节点,不用新开节点。
这样的话每次修改只涉及一个值的话,那么修改的路径最多是树高,是logn的。那么每次新开的节点不超过logn,大大节省了空间。
以上就是主席树的基本思想。
然后入门视频推荐一下UESTC的讲解视频,虽然说讲解的也不是特别特别清楚,但是那个代码我感觉还是挺不错的。再结合手推基本可以理解精髓所在。
无修改的主席树可以理解为前缀和+权值线段树。
入门题:
HDU2665 求区间第K大
题解:模板题。建好主席树后,查询就是相当于利用前缀和每次计算操作在L-R之间的数的个数,对于满足第K大,二分查找即可。而主席树本身具有二分的性质,直接根据左儿子的sum值(即比当前M小于等于的值的个数)和k比较,就可以判断往哪边走了。
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<algorithm> 5 #include<vector> 6 #include<cmath> 7 #include<string> 8 #include<set> 9 #include<queue> 10 #include<map> 11 using namespace std; 12 const int inf=(1<<30)-1; 13 const int maxn=100010; 14 #define REP(i,n) for(int i=(0);i<(n);i++) 15 #define FOR(i,j,n) for(int i=(j);i<=(n);i++) 16 #define Rep(x) for(int i=head[x],y;~i;i=e[i].next) if(!vis[y=e[i].to]) 17 typedef long long ll; 18 typedef pair<int,int> PII; 19 int IN(){ 20 int c,f,x; 21 while (!isdigit(c=getchar())&&c!='-');c=='-'?(f=1,x=0):(f=0,x=c-'0'); 22 while (isdigit(c=getchar())) x=(x<<1)+(x<<3)+c-'0';return !f?x:-x; 23 } 24 #define de(x) cout << #x << "=" << x << endl 25 #define MP make_pair 26 #define PB push_back 27 #define fi first 28 #define se second 29 int n,m,T; 30 int a[maxn],rt[maxn],tot; 31 struct data{ 32 int l,r,sum; 33 }t[maxn*20]; 34 vector<int> v; 35 inline getid(int x) { 36 return lower_bound(v.begin(),v.end(),x)-v.begin()+1; 37 } 38 void build(int l,int r,int &x) { 39 x=++tot;t[x].sum=0; 40 if(l==r) return; 41 int m=(l+r)>>1; 42 build(l,m,t[x].l); 43 build(m+1,r,t[x].r); 44 } 45 void update(int l,int r,int &x,int y,int k) { 46 x=++tot;t[x]=t[y];t[x].sum++; 47 if(l==r) return; 48 int m=(l+r)>>1; 49 if(k<=m) update(l,m,t[x].l,t[y].l,k); 50 else update(m+1,r,t[x].r,t[y].r,k); 51 } 52 int query(int l,int r,int x,int y,int k) { 53 if(l==r) return l; 54 int m=(l+r)>>1; 55 int sum=t[t[y].l].sum-t[t[x].l].sum; 56 if(k<=sum) return query(l,m,t[x].l,t[y].l,k); 57 else return query(m+1,r,t[x].r,t[y].r,k-sum); 58 } 59 int main() 60 { 61 int T;scanf("%d",&T); 62 while(T--) { 63 scanf("%d%d",&n,&m); 64 for(int i=1;i<=n;i++) scanf("%d",&a[i]),v.PB(a[i]); 65 sort(v.begin(),v.end()); 66 v.erase(unique(v.begin(),v.end()),v.end()); 67 int cnt=v.size(); 68 tot=0; 69 build(1,cnt,rt[0]); 70 for(int i=1;i<=n;i++) { 71 update(1,cnt,rt[i],rt[i-1],getid(a[i])); 72 } 73 while(m--) { 74 int l,r,x;scanf("%d%d%d",&l,&r,&x); 75 printf("%d\n",v[query(1,cnt,rt[l-1],rt[r],x)-1]); 76 } 77 } 78 return 0; 79 }
HDU4417 求区间内 <= k 的数有多少
题解:只需要将查询部分修改一下,每次返回需要查询区间内的且值的范围在当前线段树下标内的数的个数。那么对于每个询问,只需要统计[0,k]的数的个数。所以我们如果k <= mid ,就往左走。如果说反之则左边的子树需要加入答案,并往右走。
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<algorithm> 5 #include<vector> 6 #include<cmath> 7 #include<string> 8 #include<set> 9 #include<queue> 10 #include<map> 11 using namespace std; 12 const int inf=(1<<30)-1; 13 const int maxn=100010; 14 #define REP(i,n) for(int i=(0);i<(n);i++) 15 #define FOR(i,j,n) for(int i=(j);i<=(n);i++) 16 #define Rep(x) for(int i=head[x],y;~i;i=e[i].next) if(!vis[y=e[i].to]) 17 typedef long long ll; 18 typedef pair<int,int> PII; 19 int IN(){ 20 int c,f,x; 21 while (!isdigit(c=getchar())&&c!='-');c=='-'?(f=1,x=0):(f=0,x=c-'0'); 22 while (isdigit(c=getchar())) x=(x<<1)+(x<<3)+c-'0';return !f?x:-x; 23 } 24 #define de(x) cout << #x << "=" << x << endl 25 #define MP make_pair 26 #define PB push_back 27 #define fi first 28 #define se second 29 int n,m,T; 30 int a[maxn],rt[maxn],tot; 31 struct data{ 32 int l,r,sum; 33 }t[maxn*20]; 34 vector<int> v; 35 int getid(int x) { 36 return upper_bound(v.begin(),v.end(),x)-v.begin(); 37 } 38 39 void build(int l,int r,int &x) { 40 x=++tot;t[x].sum=0; 41 if(l==r) return; 42 int m=(l+r)>>1; 43 build(l,m,t[x].l); 44 build(m+1,r,t[x].r); 45 } 46 void upd(int l,int r,int &x,int y,int k) { 47 x=++tot;t[x]=t[y];t[x].sum++; 48 if(l==r) return; 49 int m=(l+r)>>1; 50 if(k<=m) upd(l,m,t[x].l,t[y].l,k); 51 else upd(m+1,r,t[x].r,t[y].r,k); 52 } 53 int que(int l,int r,int x,int y,int k) { 54 if(l==r) return t[y].sum-t[x].sum; 55 int m=(l+r)>>1; 56 if(k<=m) return que(l,m,t[x].l,t[y].l,k); 57 else return (t[t[y].l].sum-t[t[x].l].sum)+que(m+1,r,t[x].r,t[y].r,k); 58 } 59 int main() 60 { 61 scanf("%d",&T); 62 int _=0; 63 while(T--) { 64 printf("Case %d:\n",++_); 65 scanf("%d%d",&n,&m); 66 for(int i=1;i<=n;i++) scanf("%d",&a[i]),v.PB(a[i]); 67 sort(v.begin(),v.end()); 68 v.erase(unique(v.begin(),v.end()),v.end()); 69 int cnt=v.size(); 70 tot=0; 71 build(1,cnt,rt[0]); 72 for(int i=1;i<=n;i++) { 73 upd(1,cnt,rt[i],rt[i-1],getid(a[i])); 74 } 75 while(m--) { 76 int l,r,x; 77 scanf("%d%d%d",&l,&r,&x); 78 l++;r++; 79 printf("%d\n",que(1,cnt,rt[l-1],rt[r],getid(x))); 80 } 81 } 82 return 0; 83 }
SPOJ DQUERY 求区间内不同的数的个数
题解:显然离线的做法,可以先将询问排序,然后利用树状数组+扫描线的做法解决。用主席树可以在线回答询问。因为主席树是可持久化的线段树。从左往右扫描位置,每次对当前新加入的数的位置i加上1,表示该数出现次数+1。若该数之前出现过,每次用map记录下该数上次出现的位置mp[a[i]],然后在该位置-1,表示最新的位置应该在当前。对于每个询问l,r,只要找到r所在的版本的线段树,即rt[r]。对该线段树询问>=l的和即可。
trick:主席树在做更新的时候,每次更新都会开一个新的节点tot。所以对于同次版本的多次更新应该用中间变量tmp,暂时储存结果。
1 #include<bits/stdc++.h> 2 using namespace std; 3 const int maxn=1e5+10; 4 const int inf=(1<<30)-1; 5 int n,m; 6 int a[maxn],rt[maxn],tot; 7 struct data{ 8 int l,r,sum; 9 }t[maxn*20]; 10 void build(int l,int r,int &x) { 11 x=++tot;t[x].sum=0; 12 if(l==r) return; 13 int m=(l+r)>>1; 14 build(l,m,t[x].l); 15 build(m+1,r,t[x].r); 16 } 17 void upd(int l,int r,int &x,int y,int k,int v) { 18 x=++tot;t[x]=t[y];t[x].sum+=v; 19 if(l==r) return; 20 int m=(l+r)>>1; 21 if(k<=m) upd(l,m,t[x].l,t[y].l,k,v); 22 else upd(m+1,r,t[x].r,t[y].r,k,v); 23 } 24 int que(int l,int r,int x,int y) { 25 if(l==r) return t[y].sum; 26 int m=(l+r)>>1; 27 if(x<=m) return que(l,m,x,t[y].l)+t[t[y].r].sum; 28 else return que(m+1,r,x,t[y].r); 29 } 30 void dfs(int l,int r,int x) { 31 if(l==r) {printf("t[%d]=%d\n",l,t[x].sum);return;} 32 int m=(l+r)>>1; 33 dfs(l,m,t[x].l); 34 dfs(m+1,r,t[x].r); 35 } 36 int main() { 37 while(~scanf("%d",&n)) { 38 for(int i=1;i<=n;i++) scanf("%d",&a[i]); 39 map<int,int> mp; 40 tot=0; 41 build(1,n,rt[0]); 42 for(int i=1;i<=n;i++) { 43 if(!mp[a[i]]) { 44 upd(1,n,rt[i],rt[i-1],i,1); 45 //dfs(1,n,rt[i]); 46 } 47 else { 48 //cout<<i<<" "<<mp[a[i]]<<endl; 49 int tmp; 50 upd(1,n,tmp,rt[i-1],mp[a[i]],-1); 51 upd(1,n,rt[i],tmp,i,1); 52 } 53 mp[a[i]]=i; 54 } 55 56 //for(int root=1;root<=n;root++) 57 //dfs(1,n,rt[root]); 58 59 scanf("%d",&m); 60 while(m--) { 61 int l,r;scanf("%d%d",&l,&r); 62 printf("%d\n",que(1,n,l,rt[r])); 63 } 64 } 65 return 0; 66 }