总是听到主席树这个高大上的名字,仰慕已久,今天开始简单地学习
主席树是一种可持续化数据结构,这个名字与算法本身并无太大关联,听说是以创造者命名的。
主席树,简单来讲,其实就是线段树。一堆的线段树,而且每一个只由前一个修改而来。
什么意思?我们来看个例子
你有一个区间{1,3,8,7,6,7,2} 【我滚出来的】
然后你以这个区间建了一个线段树,每个节点维护区间和【其他的也行】
然后你想像,你把6改成了10,按照线段树的操作,它将会影响叶子结点6往上logN个节点
如果我们这个时候想保留旧的线段树,而新建一个线段树来保存新的线段树,我们只需在改变的节点的地方新建节点来保存,其它没变的地方直接指向旧的线段树的节点就好了。
这样我们就用logN的空间就建了一个新的线段树,还保留了历史版本,对于那些有撤销操作的题目很有作用。【发明者真是高超】
区间第K大值
其实主席树最经典的应用,就是求区间第k大值
怎么做呢?
如果我们只求区间[1,N]的第K大值,我们只需将每个数出现的频率建成一个线段树
什么?
比如对于{1,1,2,3,4} 我们线段树的叶子节点权值是{2,1,1,1}分别代表1出现过2次,2出现过1次,3出现过1次,4出现过1次。
然后往上结点只需维护区间和,然后我们就可以根据每个节点左儿子的权值大小来判断我们要找的第k大数在哪一边了。
考虑到数可能很大,我们需要将他们离散化【就是用很小的数暂时代替较大的数而保持数与数之间的大小关系】
[1,N]第K大值我们解决了,那么对于任意区间呢?
如果对于任意区间,我们能建一个这样的线段树,那不就都能写出来了么。
当然可以,不过这样子时间和空间消费都会爆掉。。。
怎么办呢?
想想,从区间[1,i-1]到区间[1,i]不过多了A[i]这一个数,也就是说我们只对线段树进行了一次修改操作,只改变了logN个结点,那么我们套用上面刚讲到的主席树不就行了耶。
然后利用前缀和和思想,查询区间[l,r]就是用[1,r]减去[1,l-1],这样就完成了。
POJ上有一个模板题 POJ2104 K-th Number
贴上我的代码
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=100005,maxm=2000005,INF=200000000;
int N;
struct node{
int l,r,ls,rs,sum;
}e[maxm];
int A[maxn],B[maxn],root[maxn],pos,siz=0;
void build(int &u,int l,int r){
u=++siz;
e[u].l=l;
e[u].r=r;
if(l==r) return;
int mid=(l+r)>>1;
build(e[u].ls,l,mid);
build(e[u].rs,mid+1,r);
}
void insert(int pre,int& u){
u=++siz;
e[u]=e[pre];
e[u].sum++;
if(e[u].l==e[u].r) return;
int mid=(e[u].l+e[u].r)>>1;
if(mid>=pos) insert(e[pre].ls,e[u].ls);
else insert(e[pre].rs,e[u].rs);
}
int Query(int pre,int u,int k){
if(e[u].l==e[u].r) return B[e[u].l];
int sum=e[e[u].ls].sum-e[e[pre].ls].sum;
if(sum>=k) return Query(e[pre].ls,e[u].ls,k);
else return Query(e[pre].rs,e[u].rs,k-sum);
}
int main()
{
int M;
scanf("%d%d",&N,&M);
for(int i=1;i<=N;i++) scanf("%d",&A[i]),B[i]=A[i];
sort(B+1,B+1+N);
build(root[0],1,N);
for(int i=1;i<=N;i++){
pos=lower_bound(B+1,B+1+N,A[i])-B;
insert(root[i-1],root[i]);
}
int l,r,k;
while(M--){
scanf("%d%d%d",&l,&r,&k);
printf("%d\n",Query(root[l-1],root[r],k));
}
return 0;
}
洛谷上也有一道板题:
洛谷P1138 中位数
思路是一样的
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=100005,INF=200000000;
inline int read()
{
int out=0,flag=1;char c=getchar();
while(c<48||c>57) {if(c=='-') flag=-1;c=getchar();}
while(c>=48&&c<=57) {out=out*10+c-48;c=getchar();}
return out*flag;
}
int A[maxn],B[maxn],num[maxn],numi=0,h[maxn],N;
class node{
public:
int l,r,sl,sr,sum;
}e[50*maxn];
int root[maxn],pos,siz=0;
void build(int& u,int l,int r){
u=++siz;
e[u].l=l;
e[u].r=r;
if(l==r) return;
int mid=(l+r)>>1;
build(e[u].sl,l,mid);
build(e[u].sr,mid+1,r);
}
void insert(int pre,int& u){
u=++siz;
e[u]=e[pre];
e[u].sum++;
if(e[u].l==e[u].r) return;
int mid=(e[u].l+e[u].r)>>1;
if(pos<=mid) insert(e[pre].sl,e[u].sl);
else insert(e[pre].sr,e[u].sr);
}
int Query(int pre,int u,int k){
if(e[u].l==e[u].r) return num[e[u].l];
int sum=e[e[u].sl].sum-e[e[pre].sl].sum;
if(k<=sum) return Query(e[pre].sl,e[u].sl,k);
else return Query(e[pre].sr,e[u].sr,k-sum);
}
int main()
{
N=read();
for(int i=1;i<=N;i++) B[i]=A[i]=read();
sort(B+1,B+1+N);
B[0]=-1;
for(int i=1;i<=N;i++){
if(B[i]!=B[i-1]) numi++;
num[numi]=B[i];
h[i]=numi;
}
build(root[0],1,N);
for(int i=1;i<=N;i++){
pos=h[lower_bound(B+1,B+1+N,A[i])-B];
insert(root[i-1],root[i]);
if(i&1) printf("%d\n",Query(root[0],root[i],(i>>1)+1));
}
return 0;
}