参考文章ACM数据结构(一)——主席树
可持久化专题(一)——浅谈主席树:可持久化线段树
简介
主席树为什么叫主席树?据说因为它是一个名字缩写为HJT的神犇发明的,与当时主席的名字缩写一样,主席树实质上就是一棵可持久化线段树,它的具体实现可以看下面。
值域线段树
要学主席树,我们就要先学值域线段树。
值域线段树的区间存的并不是节点信息,而是在值在某一范围内的数的个数。
也就是说跟区间线段树不同的是每个节点的区间表示的不再是数组的一个区间的信息,而是数组中的值域内的值的个数。
值域线段树的查询也挺简单的,若要查询这段区间内的第k大,只要比较当前元素的左子树大小加1(1是当前元素本身的大小)与询问的k,若大于等于,就访问左子树,否则将k减去当前元素的左子树大小加1,然后访问右子树。
还有一个问题,就是值域线段树存储的区间范围是固定的,所以如果要查询区间第k大,我们就不能只用一棵值域线段树。
考虑建n棵值域线段树,每棵值域线段树存储区间[1,i]的信息,这样一来,要查询[l,r]的第k大时,只要在查询的过程中,将第r棵值域线段树的信息减去第l−1棵值域线段树的信息即可,这利用了前缀和的思想。
主席树
知道了值域线段树,我们就可以开始尝试实现主席树了。
来研究一下下面两棵分别存储[1,3]和[1,4]区间信息的值域线段树(圆圈中为以该节点为根的子树大小)。
仔细观察可得,我们每次新加入一个节点,有影响的只有图中标红的节点。
再仔细观察一下,这些节点都在一条链上。
那么,我们就会有一个大胆的想法:每次只新建一条链而不是一棵树,就像下面这样
这就是传说中的主席树了。
代码
- 创建根节点,左右儿子结点数组
int tot=0,rt[maxn*20],lson[maxn*20],rson[maxn*20],v[maxn*20],lz[maxn*20],a[maxn];
tot是每次新建的结点编号。
rt[i]是第i棵线段树的根节点的编号。
lson[x]和rson[x]是结点x的左右儿子结点的编号。
v[x]是结点x代表的区间的和。
lz[x]是结点x的懒惰(lazy)值。
a[i]是初始的第i个位置的值。
因为结点每次至多更新O(logn)个,所以数组范围应该在原来的20-50倍左右。
2. 区间更新的pushup和pushdown
void push_up(int x){
v[x]=v[lson[x]]+v[rson[x]];
}
void push_down(int x,int len){
if(lz[x]){
v[lson[x]]+=(len>>1)*lz[x];
v[rson[x]]+=(len-(len>>1))*lz[x];
lz[lson[x]]+=(len>>1)*lz[x];
lz[rson[x]]+=(len-(len>>1))*lz[x];
lz[x]=0;
}
}
线段树更新基础
3. 建树
void build(int &x,int l,int r){
x=++tot;
lz[x]=0;
if(l==r){
v[x]=a[l];
return;
}
int mid=l+r>>1;
build(lson[x],l,mid);
build(rson[x],mid+1,r);
push_up(x);
}
和线段树的思想是一样的,只是在调用过程中,我们以引用的形式,实现对rt,lson,rson的更新。
建树的调用如下:
build(rt[0],1,n);
- 更新
void update(int L,int R,int l,int r,int &x,int last,int val){
x=++tot;
lson[x]=lson[last];rson[x]=rson[last];lz[x]=lz[last];v[x]=v[last];
if(L<=l&&R>=r){
v[x]+=(r-l+1)*val;lz[x]+=val;
return;
}
push_down(x,r-l+1);
int mid=l+r>>1;
if(L<=mid) update(L,R,l,mid,lson[x],lson[last],val);
if(R>mid) update(L,R,mid+1,r,rson[x],rson[last],val);
push_up(x);
}
第1行开辟了新的结点,第2行进行了结点复用,last就是上一棵线段树的结点,从根节点向下更新。
更新的调用如下:
update(x,y,1,n,rt[i],rt[i-1],w);
- 查询
int query(int L,int R,int l,int r,int x){
if(L<=l&&R>=r){
return v[x];
}
push_down(x,r-l+1);
int mid=l+r>>1,sum=0;
if(L<=mid) sum+=query(L,R,l,mid,lson[x]);
if(R>mid) sum+=query(L,R,mid+1,r,rson[x]);
push_up(x);
return sum;
}
查询就是简单的区间查询。
查询的调用如下:
query(x,y,1,n,rt[i]);