主席树(可持久化线段树)入门专题

1.poj 2104 查询区间第k小。

主席树其实相当于建立了n棵线段树,第i棵线段树是根据区间【1,i】按值建立的。对于每一棵线段树我们记录它对应的区间每个数出现的次数,所以首先要对所有的数离散化。

先考虑最简单的情况,只查询【1,n】的第k小,对于【1,n】我们按值建立一棵线段树,对于a[i]我们在位置a[i]上加1。查询第k小那么先看左子区间出现了多少个数cnt,假设左区间出现的数cnt>=k,那么直接递归到左区间查询(因为是按值建立的,左区间的数肯定小于右区间),否则递归到右区间查询第k-cnt小(左区间已经有了最小的cnt个数了)


对于任意区间查询【l,r】,我们只需要比较第l-1棵线段树和第r棵线段树,【l,r】之间的数就是第r棵线段树相比于第l-1棵多出来的数。只需要对比两颗树同一个节点,对比到哪个数为止第r棵比第l-1棵刚好多出k个数。(先比较左区间cntr-cntl,cntr-cntl>=k,则递归到左区间,否则递归查询右区间k-cntr-cntl)


主席树就相当于n棵线段树,但是对比建立在【1,i】的线段树和【1,i+1】的线段树,只多出了一个值,也就是相当于单点更新他们之间只有logn个节点是不同的,所以可以将【1,i+1】的一些节点指针指向前一棵的共同部分。每次新增的空间只需要logn。


代码是根据kuangbin模板抄的。。。。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <string>
#include <vector>
#include <map>
#include <set>
#include <vector>
#include <queue>
#include <stack>
#include <algorithm>
using namespace std;

const int maxn=1e5+10;
const int M=maxn*30;

int n,q,m,tot;
int a[maxn], t[maxn];
int T[maxn], lson[M], rson[M], c[M];

void init_hash()
{
    for(int i=1; i<=n; i++)
        t[i]=a[i];
    sort(t+1, t+n+1);
    m=unique(t+1, t+n+1)-t-1;
}

int hash(int x)
{
    return lower_bound(t+1, t+1+m, x)-t;
}

int build(int l, int r)
{
    int rt=tot++;
    c[rt]=0;
    if(l!=r){
        int mid=(l+r)>>1;
        lson[rt]=build(l,mid);
        rson[rt]=build(mid+1, r);
    }
    return rt;
}


int update(int rt, int pos, int val)
{
    int newrt=tot++,tmp=newrt;
    c[newrt]=c[rt]+val;
    int l=1, r=m;
    while(l<r){
        int mid=(l+r)>>1;
        if(pos<=mid){
            lson[newrt]=tot++; rson[newrt]=rson[rt];
            newrt=lson[newrt]; rt=lson[rt];
            r=mid;
        }
        else{
            rson[newrt]=tot++; lson[newrt]=lson[rt];
            newrt=rson[newrt]; rt=rson[rt];
            l=mid+1;
        }
        c[newrt]=c[rt]+val;
    }
    return tmp;
}

int query(int lrt, int rrt, int k)
{
    int l=1, r=m;
    while(l<r){
        int mid=(l+r)>>1;
        if(c[lson[lrt]]-c[lson[rrt]]>=k){
            r=mid;
            lrt=lson[lrt];
            rrt=lson[rrt];
        }
        else{
            l=mid+1;
            k-=c[lson[lrt]]-c[lson[rrt]];
            lrt=rson[lrt];
            rrt=rson[rrt];
        }
    }

    return l;
}

int main()
{
    while(cin>>n>>q){
        
        for(int i=1; i<=n; i++)
            scanf("%d", a+i);
        init_hash();
        tot=0;
        T[n+1]=build(1,m);

        for(int i=n; i; i--){
            int pos=hash(a[i]);
            T[i]=update(T[i+1], pos, 1);
        }
        while(q--){
            int l,r,k;
            scanf("%d%d%d", &l, &r, &k);
            printf("%d\n", t[query(T[l], T[r+1], k)]);
        }
    }
    return 0;
}



2.hdu 4417 区间查询<=H的数有多少个

查询【l,r】区间只需要将第r棵线段树【0,H】区间的总数减去第l-1棵的就行了。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <string>
#include <vector>
#include <map>
#include <set>
#include <vector>
#include <queue>
#include <stack>
#include <algorithm>
using namespace std;
const int maxn=1e5+10;
const int maxm=maxn*30;

int n,m, N;
int a[maxn],b[2*maxn];
int T[maxn],tot;
int lson[maxm],rson[maxm], cnt[maxm];

int build(int l, int r)
{
    int rt=tot++;
    cnt[rt]=0;
    if(l==r) return rt;
    int mid=(l+r)>>1;
    lson[rt]=build(l, mid);
    rson[rt]=build(mid+1, r);
    return rt;
}


int update(int rt, int pos, int v)
{
    int newrt=tot++, ret=newrt;
    int l=1, r=N;
    cnt[newrt]=cnt[rt]+v;
    while(l<r){
        int mid=(l+r)>>1;
        if(pos<=mid){
            lson[newrt]=tot++; rson[newrt]=rson[rt];
            newrt=lson[newrt]; rt=lson[rt];
            r=mid;
        }
        else{
            lson[newrt]=lson[rt]; rson[newrt]=tot++;
            newrt=rson[newrt]; rt=rson[rt];
            l=mid+1;
        }
        cnt[newrt]=cnt[rt]+v;
    }

    return ret;
}


int query(int lrt, int rrt, int pos)
{
    int ret=0;
    int l=1,r=N;
    while(l<r){
        int mid=(l+r)>>1;
        if(pos<=mid){
            lrt=lson[lrt]; rrt=lson[rrt];
            r=mid;
        }
        else{
            ret+=cnt[lson[rrt]]-cnt[lson[lrt]];
            lrt=rson[lrt]; rrt=rson[rrt];
            l=mid+1;
        }
    }
    ret+=cnt[rrt]-cnt[lrt];
    return ret;
}


int l[maxn], r[maxn], h[maxn];
int main()
{
    int t;
    cin>>t;
    for(int tt=1; tt<=t; tt++){
        cin>>n>>m;
        for(int i=1; i<=n; i++){
            scanf("%d", a+i);
            b[i]=a[i];
        }

        for(int i=0; i<m; i++){
            scanf("%d%d%d", l+i, r+i, h+i);
            b[n+1+i]=h[i];
        }
        sort(b+1, b+1+n+m);
        N=unique(b+1, b+1+n+m)-b-1;

        tot=0;
        T[0]=build(1,N);
        for(int i=1; i<=n; i++){
            int v=lower_bound(b+1, b+1+N, a[i])-b;
            T[i]=update(T[i-1], v, 1);
        }

        printf("Case %d:\n", tt);
        for(int i=0; i<m; i++){
            int v=lower_bound(b+1, b+1+N, h[i])-b;
            printf("%d\n", query(T[l[i]], T[r[i]+1], v));
        }


    }
    return 0;
}



3.hdu 4348 可持久化线段树,区间更新,不下放的懒惰标记(空间优化)

主席树其实就是可持久化线段树。可持久化就是每次修改操作,尽量用新节点表示而不是直接修改原来的点,这样所有的历史版本都得以保留。

主要麻烦的就是区间更新。区间更新对于完全覆盖的区间要用lazy标记。但是每次lazy下放的时候两个子区间都发生修改需要创造两个新的节点,这样到最后下放到最后一层相当于消耗了O(n)个新节点,空间会爆。

这道题用的空间优化就是不下放标记。标记就打在那个区间节点上,而查询的时候,往下递归时遇到标记就累加,最后把标记的影响加到总答案里。这样就不需要创造那么多新节点了。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <string>
#include <vector>
#include <map>
#include <set>
#include <vector>
#include <queue>
#include <stack>
#include <algorithm>
using namespace std;
const int maxn=1e5+1000;
const int maxm=30*maxn;
typedef long long LL;
int n, m;
int a[maxn];
int T[maxn], tot=0;
int lson[maxm],rson[maxm], lazy[maxm];
LL sum[maxm];

void push_up(int rt, int l, int r)
{
    sum[rt]=sum[lson[rt]]+sum[rson[rt]]+(LL)lazy[rt]*(r-l+1);
}


int build(int l, int r)
{
    int rt=tot++;
    lazy[rt]=0;
    if(l==r){
        sum[rt]=a[l];
        return rt;
    }

    int mid=(l+r)>>1;
    lson[rt]=build(l,mid);
    rson[rt]=build(mid+1, r);
    push_up(rt, l, r);
    return rt;
}

int update(int rt, int l, int r, int ll, int rr, int v)
{
    int newrt=tot++;
    lazy[newrt]=lazy[rt];
    if(ll<=l && r<=rr){
        lson[newrt]=lson[rt], rson[newrt]=rson[rt];
        lazy[newrt]=lazy[rt]+v;
        sum[newrt]=sum[rt]+(LL)v*(r-l+1);
        return newrt;
    }


    int mid=(l+r)>>1;
    if(rr<=mid){
        rson[newrt]=rson[rt];
        lson[newrt]=update(lson[rt], l, mid, ll, rr, v);
    }
    else if(ll>mid){
        lson[newrt]=lson[rt];
        rson[newrt]=update(rson[rt], mid+1, r,ll, rr, v);
    }
    else{
        lson[newrt]=update(lson[rt], l, mid, ll, mid, v);
        rson[newrt]=update(rson[rt], mid+1, r, mid+1, rr, v);
    }

    push_up(newrt, l, r);
    return newrt;
}


LL query(int rt, int l, int r, int ll, int rr, int  la)
{

    if(ll<=l && r<=rr){
        return sum[rt]+(LL)la*(r-l+1);
    }
    la+=lazy[rt];

    int mid=(l+r)>>1;
    if(rr<=mid)
        return query(lson[rt], l, mid, ll, rr, la);
    else if(ll>mid)
        return query(rson[rt], mid+1, r, ll, rr, la);
    else
        return query(lson[rt], l, mid, ll, mid, la)+query(rson[rt], mid+1, r, mid+1, rr, la);
}


int main()
{
    while(cin>>n>>m){
        for(int i=1; i<=n; i++)
            scanf("%d", a+i);

        tot=0;
        T[0]=build(1,n);
        int tag=0;
        while(m--){
            char s[5];
            int l,r,d;
            scanf("%s%d", s, &l);
            if(s[0]=='B'){
                tag=l;
            }
            else{
                scanf("%d", &r);
                if(s[0]=='Q'){
                    LL res=query(T[tag], 1, n, l, r, 0);
                    printf("%I64d\n", res);
                }
                else{
                    scanf("%d", &d);
                    if(s[0]=='C'){
                        tag++;
                        T[tag]=update(T[tag-1], 1, n, l, r, d);
                    }
                    else{
                        LL res=query(T[d], 1, n, l, r, 0);
                        printf("%I64d\n", res);
                    }
                }
            }
        }
    }
    return 0;
}



没有更多推荐了,返回首页