关于treap启发式合并的一点脑洞(以bzoj2809为例)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/scpointer/article/details/54786203

首先我知道bzoj2809正解应该是可并堆,之所以写treap启发式合并单纯只是因为这个脑洞…

首先我们有两个treap,分别是 A 和 B ,它们的节点数分别为 n 和 m (n<m)

网上普通的启发式合并是把 A 中的结点一个个扔进 B 里,这样复杂度是 O(nlogm) 的。

然而 zyqn 告诉我 myy 说这个事情可以做到 O(nlogmn) ,这就非常一颗赛艇了。

因为这意味着,如果你有 n 个只有一个结点的 treap,然后我每次指定两个让你合并,顺便询问询问,最后合并成一个 treap, 那么这个事情可以做到 O(nlogn) 而不是 O(nlog2n)

然后我搞了搞发现我可能写挂了,不过还是决定放上来讲讲。

具体来说是把 B 那个 treap 拆成一个序列,然后塞进 A 相应的地方。就像这样:

示意图

但是光这么搞会出一些问题,比如有 n 个结点,权值分别为 1~n ,那把第二个,第三个…第 n 个点依次塞进第一个点的 treap 里,这个 treap 就会变成一条链。补救办法是在某些时候把需要 merge 的 A 和 B 两个 treap 都拆成序列,然后暴力重构整个 treap 的这个子树。

然后问题就是什么时候暴力重构了,这个可以参考普通 treap,如果两个 treap 的结点数分别为 n 和 m (n<m),那么我们以 nn+m 的概率重构它。还有如果最开始被拆成序列的那个 treap 如果在 merge 到某个点的时候,序列长度比当前没被拆的那个 treap 的大小要大了,那暴力重构肯定不会影响复杂度。

这么一通搞完以后我发现复杂度我并不会证。也许是 O(nlogmn) ,也许还是 O(nlogm)

不过这样写有一个好,就是如果这个 treap 不需要可持久化的话,空间复杂度是 O() 的,不会出现卡空间的问题。但是要多写这么多函数啊喂,如果是OI比赛调不出来怎么办

附代码

#include <cstdio>
#include <algorithm>
#define N 100050
#define NDS 1000050
#define INF 1000000001
#define swap(a,b) std::swap(a,b)
//namespace fastin
//{
    #define BUF 100005
    char __s[BUF],*_ss,*_st;
    inline char gchar()
    {
        if(_ss==_st) _st=(_ss=__s)+fread(__s,sizeof(char),BUF-5,stdin);
        return *_ss++;
    }
    inline int RD()
    {
        int res;char cr;
        while( (cr=gchar())<'0' || cr>'9');res=cr-'0';
        while( (cr=gchar())>='0' && cr<='9') res=res*10+cr-'0';
        return res;
    }
    #undef BUF
//}

int trand()
{
    static long long x=2947183,a=3798293,b=1000000009;
    return x=(x*a+b)&0x7fffffff;
}
int trand(int lim){return trand()%lim;}

#define ls p->son[0]
#define rs p->son[1]
struct Node *null;
struct Node{
    int sz,val;
    long long totval;
    Node *son[2];
    void update()
    {
        sz=(son[0]->sz)+(son[1]->sz)+1;
        totval=(son[0]->totval)+(son[1]->totval)+val;
        if(totval>INF) totval=INF;
    }
}_pol[NDS],*_ndcnt;
Node* newNode(){return _ndcnt++;}
void init()
{
    null=new Node();
    *null=(Node){0,0,0,{NULL,NULL}};
    _ndcnt=_pol;
}

Node *que[N],*q2[N];
int ql,qr,q2top;
void takeApart(Node *p)//把treap p拆成序列
{
    if(p==null) return;
    takeApart(ls);
    que[++qr]=p;
    takeApart(rs);
}
void combine(Node *p,int &l,int r)//把treap p拆成序列,并且和que里的序列合成一个新序列
{
    if(p==null) return;
    combine(ls,l,r);
    while(l<=r && (que[l]->val)<(p->val))
        q2[++q2top]=que[l++];
    q2[++q2top]=p;
    combine(rs,l,r);
}
Node* rebuild(int l,int r)//暴力重建,把序列变成一个treap
{
    if(l>r) return null;
    int mid=(l+r)>>1;
//  Node *p=newNode();*p=*q2[mid];
    Node *p=q2[mid];
    ls=rebuild(l,mid-1);
    rs=rebuild(mid+1,r);
    p->update();
    return p;
}
Node* rebuild(Node *p,int l,int r)//暴力重建treap p和序列que(中的[l,r]这一段)
{
    q2top=0;
    combine(p,l,r);
    while(l<=r) q2[++q2top]=que[l++];
    return rebuild(1,q2top);
}
int find(int l,int r,int val)
{
    //判定merge下传的时候应该把que序列中的前一半扔给左儿子,后一半扔给右儿子
    //这个函数是用来找前一半和后一半的分界点的
    //不要问我为什么不直接二分
    for(int del=0;;del++)
    {
        if(val<(que[l+del]->val))
            return l+del-1;
        if(val>=(que[r-del]->val))
            return r-del;
    }
}
Node* merge(Node *rt,int l,int r)
{
    //合并treap rt和序列que(中的[l,r]这一段)。这个代码的核心。
//  Node *p=newNode();*p=*rt;
    Node *p=rt;
    if(l>r) return p;
    if(p==null) return rebuild(p,l,r);
    int size1=p->sz,size2=r-l+1;
    if(size1<=size2 || trand(size1+size2)<size2)
        return rebuild(p,l,r);
    int mid=find(l,r,p->val);
    ls=merge(rt->son[0],l,mid);
    rs=merge(rt->son[1],mid+1,r);
    p->update();
    return p;
}
Node* merge(Node *a,Node *b)//合并a,b两个treap,其中b被拆成序列
{
    if((a->sz)<(b->sz)) swap(a,b);
    ql=1;qr=0;
    takeApart(b);
    return merge(a,ql,qr);
}

/*int query(Node* p,int maxcost)
{
    if(p==null) return 0;
    if(p->totval<=maxcost) return p->sz;
    if(ls->totval<=maxcost)
    {
        if((ls->totval)+(p->val)<=maxcost)
            return ls->sz+1+query(rs,(maxcost)-(ls->totval)-(p->val));
        else return ls->sz;
    }
    else return query(ls,maxcost);
}
*/
int query(Node* &p,int maxcost)//本来应该用上面那个query(),这个函数是用题目性质加了一些常数优化
{
    if(p==null) return 0;
    if(p->totval<=maxcost) return p->sz;
    int res;
    if(ls->totval<=maxcost)
    {
        if((ls->totval)+(p->val)<=maxcost)
            res=ls->sz+1+query(rs,(maxcost)-(ls->totval)-(p->val));
        else res=ls->sz,p=ls;
    }
    else res=query(ls,maxcost),p=ls;
    if(p!=null) p->update();
    return res;
}

int fa[N],cost[N],leadership[N];
int fson[N],bro[N];
Node *rot[N];
int main()
{
    init();
    int n,maxcost;n=RD();maxcost=RD();
    for(int i=1;i<=n;i++)
    {
        fa[i]=RD();cost[i]=RD();leadership[i]=RD();
        bro[i]=fson[fa[i]];
        fson[fa[i]]=i;
    }
    long long ans=0;
    for(int i=n;i>=1;i--)
    {
        rot[i]=newNode();
        *rot[i]=(Node){1,cost[i],cost[i],{null,null}};
        for(int j=fson[i];j;j=bro[j])
            rot[i]=merge(rot[i],rot[j]);
        long long res=(long long)leadership[i]*query(rot[i],maxcost);
        if(res>ans) ans=res;
    }
    printf("%lld\n",ans);
}

展开阅读全文

没有更多推荐了,返回首页