我的treap模板

今年寒假时封装了一个支持查询rank的treap。
然后发现这样无法支持指针的O(1)加减。事实上通过维护指向前继和后继的指针可以实现迭代器的O(1)加减。
今天就又写了一个treap模板,封装性自我感觉良好,有自己的迭代器,而且速度还行,在洛谷的普通平衡树一题中是第16页,总共2700份左右的AC代码。
同时为了测试迭代器和begin指针,还放到快排和堆的模板题里测试,发现我的treap常数是快排3倍不止(还有I/O的硬指标)(我96ms一个点,别人32ms一个点)
在堆的模板题里好像虐stl的priority_queue
然后在本地发现我的multiset插 6105 个int只要0.75秒,而我的treap插 6105 个int却要1秒。连不开O2的stl的multiset都跑不过,人生还有什么希望?
突然感觉可能常数跟lct差不多。
我觉得之所以这么慢,根本原因是我太菜了,直接原因是我的程序不记father的,这样旋转省了常数,但是导致我维护的辅助链表对于程序其他操作的效率没有帮助,本来记了父亲,删除迭代器就可以 O(1) 实现,最多再 O(logn) 自底向上维护一下size(这不必递归,省常数)。因为普通BST是通过取前后继来替代旋转到叶子,而我的辅助链表可以O(1)查一个迭代器的前后继。
而删除键值,甚至可以非递归找到其所在的迭代器,再用前述方法删除,无需递归。
其实插入也可以迭代实现?
然而并不想写,6K代码写完,已经累觉不爱

#include <cstdio>
namespace GenHelper
{
    unsigned z1,z2,z3,z4,b;
    unsigned rand_()
    {
    b=((z1<<6)^z1)>>13;
    z1=((z1&4294967294U)<<18)^b;
    b=((z2<<2)^z2)>>27;
    z2=((z2&4294967288U)<<2)^b;
    b=((z3<<13)^z3)>>21;
    z3=((z3&4294967280U)<<7)^b;
    b=((z4<<3)^z4)>>12;
    z4=((z4&4294967168U)<<13)^b;
    return (z1^z2^z3^z4);
    }
}
void srand(unsigned x)
{using namespace GenHelper;
z1=x; z2=(~x)^0x233333333U; z3=x^0x1234598766U; z4=(~x)+51;}
int rand()
{
    using namespace GenHelper;
    int a=rand_()&32767;
    int b=rand_()&32767;
    return a*32768+b;
}
template<typename T> class treap{
    private:
        struct node{
            node*l,*r,*a[2];
            int p,size,w;
            T v;
            node(T _v):v(_v){l=r=a[0]=a[1]=NULL,p=rand(),w=size=1;}
            void maintain(){size=w+(l?l->size:0)+(r?r->size:0);}
        };
        node*head,*mi;
        void lturn(node* &x){
            node*t=x->r;
            x->r=t->l;
            t->l=x;
            x->maintain();
            t->maintain();
            x=t;
        }
        void rturn(node* &x){
            node*t=x->l;
            x->l=t->r;
            t->r=x;
            x->maintain();
            t->maintain();
            x=t;
        }
        void ins(node* &o,int y,node*fa,int v){
            if(o==NULL){
                o=new node(y);
                o->a[v^1]=fa;
                o->a[v]=fa->a[v];
                if(fa->a[v])fa->a[v]->a[v^1]=o;
                fa->a[v]=o;
            }else if(y>o->v){
                ins(o->r,y,o,1);
                if(o->r->p>o->p)lturn(o);
            }else if(y<o->v){
                ins(o->l,y,o,0);
                if(o->l->p>o->p)rturn(o);
            }else ++o->w;
            o->maintain();
        }
        void del(node* &x,int y){
            if(x==NULL)return;
            if(y>x->v)del(x->r,y);
                else if(y<x->v)del(x->l,y);
                    else{
                        if(x->w>1){
                            --x->size;
                            --x->w;
                            return;
                        }
                        if(x->l==NULL){
                            node*z=x;
                            if(x->a[0])x->a[0]->a[1]=x->a[1];
                            if(x->a[1])x->a[1]->a[0]=x->a[0];
                            x=x->r;
                            delete z;
                            return;
                        }
                        if(x->r==NULL){
                            node*z=x;
                            if(x->a[0])x->a[0]->a[1]=x->a[1];
                            if(x->a[1])x->a[1]->a[0]=x->a[0];
                            x=x->l;
                            delete z;
                            return;
                        }
                        if(x->l->p>x->r->p){                                        
                            rturn(x);
                            del(x->r,y);
                        }else{
                            lturn(x);
                            del(x->l,y);
                        }
                    }
            if(x!=NULL)--x->size;
        }
    public:
        struct iterator{
            node* x;
            iterator(node*_x=NULL):x(_x){}
            bool operator!=(const iterator&rhs)const{return x!=rhs.x;}
            bool operator==(const iterator&rhs)const{return x==rhs.x;}
            T operator*(){return x->v;}
            iterator operator++(){return x=x->a[1];}
            iterator operator++(int){register node*t=x;x=x->a[1];return t;}
            iterator operator--(){return x=x->a[0];}
            iterator operator--(int){register node*t=x;x=x->a[0];return t;}
        };

        treap(){srand(19260817);}
        inline void insert(T x){
            if(head==NULL)mi=head=new node(x);
                else ins(head,x,NULL,0),mi=mi->a[0]?mi->a[0]:mi;
        }
        inline void erase(T x){
            mi=mi->v==x && mi->w==1?mi->a[1]:mi;del(head,x);
        }
        inline void erase(iterator x){
            erase(*x);
        }
        inline T rank(T y){
            register node*x=head;
            register int ans=0,s;
            while(x!=NULL){
                s=x->l?x->l->size:0;
                if(y==x->v)return ans+s+1;
                if(y>x->v){
                    ans+=s+x->w;
                    x=x->r;
                }else x=x->l;
            }
            return ans+1;
        }
        inline T kth(T y){
            register node*x=head;
            register int u,v;
            for(;;){
                u=x->l?x->l->size:0,v=x->w;
                if(u<y && u+v>=y)return x->v;
                if(y>u+v){
                    x=x->r;
                    y-=u+v;
                }else x=x->l;
            }
        }
        inline iterator prec(T y){
            node*x=head,*t=NULL;
            while(x!=NULL){
                if(x->v>=y)x=x->l;
                    else{
                        t=x;
                        x=x->r;
                    }
            }     
            return t;
        }
        inline iterator succ(T y){
            node*x=head,*t=NULL;
            while(x)
                if(x->v<=y)x=x->r;
                    else{
                        t=x;
                        x=x->l;
                    }
            return t;
        }
        inline iterator find(T v){
            register node*x=head;
            for(;x && x->v!=v;x=v<x->v?x->l:x->r);
            return x;
        }
        inline int count(T v){
            register iterator t=find(v);return t.x?t.x->w:0;
        }
        inline iterator begin(){return mi;}
        inline iterator end(){return NULL;}
};
treap<int> t;
int main(){
    int n;
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        int x,y;
        scanf("%d%d",&x,&y);
        if(x==1)t.insert(y);
        if(x==2)t.erase(y);
        if(x==3)printf("%d\n",t.rank(y));
        if(x==4)printf("%d\n",t.kth(y));
        if(x==5)printf("%d\n",*t.prec(y));
        if(x==6)printf("%d\n",*t.succ(y));
    }
    return 0;
}

upd:1.1版本
今天把老版本里极不优美的递归实现改成了迭代实现,好像快了一些?(224ms->196ms,都没有I/O优化)
除了快一些,好像代码也短了一些,实测40万次插入+查rank,本机900ms左右,作为对比的__gnu_pbds::rb_tree_tag(我不跟stl的set比,因为stl的set不能查rank),同样操作要1560ms左右,真开心。
附上当前版本的几个要点(或问题)
1.需要重载==,<,<=,>,>=。本来可以不重载,但重载可以给写模板的我省事
2.本模板使用游程编码,重复键值会导致计数器++,而不是使节点数增加
3.本模板的rk返回的是小于给定键值的元素个数再加1
4.本模板的erase即使是键值,也只是使计数器–,因为目前还没有彻底删除的需求
5.此版本目前不支持–s.end()
好像手写的treap比红黑树慢很多?看来理论分析出来的常数在实践中还是有意义的。
下面就是1.1版本的代码

#include<cstdio>
#include<cctype>
#include<ctime>
template<typename T> class treap{
    private:
        struct node{
            node*ch[2],*a[2],*fa;
            unsigned int p;
            int size,w;
            T v;
            node(T _v):v(_v){
                static unsigned int seed=19260817;
                ch[0]=ch[1]=a[0]=a[1]=fa=NULL;
                p=seed^=seed>>13,seed^=seed<<21,seed^=seed>>17,seed^=seed<<24;
                w=size=1;
            }
            void maintain(){size=w+(ch[0]?ch[0]->size:0)+(ch[1]?ch[1]->size:0);}
            inline int lr(){return fa->ch[1]==this;}
        };
        node*rt,*mi;
        inline void rotate(node*x){
            node*y=x->fa,*z=y->fa;
            if(z)z->ch[y->lr()]=x;
            int o=x->lr();
            x->fa=z,y->fa=x;
            y->ch[o]=x->ch[!o];
            if(x->ch[!o])x->ch[!o]->fa=y;
            x->ch[!o]=y;y->maintain();x->maintain();
        }
    public:
        struct iterator{
            node* x;
            iterator(node*_x=NULL):x(_x){}
            bool operator!=(const iterator&rhs)const{return x!=rhs.x;}
            bool operator==(const iterator&rhs)const{return x==rhs.x;}
            T operator*(){return x->v;}
            iterator operator++(){return x=x->a[1];}
            iterator operator++(int){node*t=x;x=x->a[1];return t;}
            iterator operator--(){return x=x->a[0];}
            iterator operator--(int){node*t=x;x=x->a[0];return t;}
        };
        inline iterator find(T x){
            for(node*i=rt;i!=NULL;i=x<i->v?i->ch[0]:i->ch[1])if(i->v==x)return i;
            return NULL;
        }
        inline void insert(T x){
            if(!rt){
                rt=mi=new node(x);
                return;
            }
            node*i=rt,*j=NULL;int o;
            while(i!=NULL){
                j=i;++i->size;
                if(x==i->v){++i->w;return;}
                i=x<i->v?i->ch[o=0]:i->ch[o=1];
            }
            i=new node(x);i->fa=j;
            j->ch[o]=i;
            if(j->a[o])j->a[o]->a[!o]=i;
            i->a[o]=j->a[o];
            i->a[!o]=j;
            j->a[o]=i;
            if(mi==NULL || x<mi->v)mi=i;
            while(i->fa!=NULL && i->fa->p<i->p)
                rotate(i);
            if(i->fa==NULL)rt=i;
        }
        inline void erase(node*x){
            if(x->w>1){
                --x->w;
                for(;x!=NULL;x=x->fa)
                    --x->size;
                return;
            }
            if(x==rt && rt->size==1){delete rt;mi=rt=NULL;return;}
            if(x->a[0])x->a[0]->a[1]=x->a[1];
            if(x->a[1])x->a[1]->a[0]=x->a[0];
            if(x==mi)mi=x->a[1];
            if(!x->ch[0]){
                if(x->fa)x->fa->ch[x->lr()]=x->ch[1];
                if(x->ch[1])x->ch[1]->fa=x->fa;
                for(node*y=x->fa;y;y=y->fa)--y->size;delete x;
            }else{
                node*y=x->a[0],*u=y->ch[0];
                if(u!=NULL)u->fa=y->fa,y->fa->ch[y->lr()]=u;else y->fa->ch[y->lr()]=NULL;
                x->v=y->v,x->w=y->w,x->a[0]=y->a[0],x->a[1]=y->a[1];
                if(y->a[0])y->a[0]->a[1]=x;
                if(y->a[1])y->a[1]->a[0]=x;
                node*z=y->fa;
                for(;z!=x;z=z->fa)z->size-=y->w;
                for(;z;z=z->fa)--z->size;
                delete y;
            }   
        }
        inline void erase(T x){
            erase(find(x).x);
        }
        inline void erase(iterator x){
            erase(x.x);
        }
        inline int rank(T y){
            node*x=rt;
            int ans=0,s;
            while(x){
                s=x->ch[0]?x->ch[0]->size:0;
                if(y==x->v)return ans+s+1;
                if(y>x->v){
                    ans+=s+x->w;
                    x=x->ch[1];
                }else x=x->ch[0];
            }
            return ans+1;
        }
        inline T kth(T y){
            node*x=rt;
            int u,v;
            for(;;){
                u=x->ch[0]?x->ch[0]->size:0,v=x->w;
                if(u<y && u+v>=y)return x->v;
                if(y>u+v){
                    x=x->ch[1];
                    y-=u+v;
                }else x=x->ch[0];
            }
        }
        inline iterator prec(T y){
            node*x=rt,*t=NULL;
            while(x)
                if(x->v<y)t=x,x=x->ch[1];
                    else x=x->ch[0];
            return t;
        }
        inline iterator succ(T y){
            node*x=rt,*t=NULL;
            while(x)
                if(y<x->v)t=x,x=x->ch[0];
                    else x=x->ch[1];
            return t;
        }
        inline int count(T v){
            iterator t=find(v);return t.x!=NULL?t.x->w:0;
        }
        inline iterator begin(){return mi;}
        inline iterator end(){return NULL;}
        inline bool empty(){return rt==NULL;}
        inline int size(){return rt?rt->size:0;}
};
treap<int> t;
inline void read(int&x){
    char c=getchar();int f=1;
    for(;!isdigit(c);c=getchar())f=c=='-'?-1:f;
    for(x=0;isdigit(c);c=getchar())x=x*10+c-48;x*=f;
}
int main(){
    int n;
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        int x,y;
        scanf("%d%d",&x,&y);
        if(x==1)t.insert(y);
        if(x==2)t.erase(y);
        if(x==3)printf("%d\n",t.rank(y));
        if(x==4)printf("%d\n",t.kth(y));
        if(x==5)printf("%d\n",*t.prec(y));
        if(x==6)printf("%d\n",*t.succ(y));
    }
    return 0;
}
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值