[bzoj4923][splay]K小值查询

版权声明:他无力阻止你的转载 https://blog.csdn.net/qq_36993218/article/details/79980303

4923: [Lydsy1706月赛]K小值查询

Time Limit: 15 Sec Memory Limit: 256 MB
Submit: 415 Solved: 121
[Submit][Status][Discuss]
Description

维护一个长度为n的正整数序列a_1,a_2,…,a_n,支持以下两种操作:
1 k,将序列a从小到大排序,输出a_k的值。
2 k,将所有严格大于k的数a_i减去k。
Input

第一行包含两个正整数n,m(1<=n,m<=100000),分别表示序列的长度和操作的个数。
第二行包含n个正整数a_1,a_2,…,a_n(1<=a_i<=10^9),分别表示序列中的每个元素。
接下来m行,每行两个正整数op(1<=op<=2),k,若op=1,则1<=k<=n;若op=2,则1<=k<=10^9;依次描述每个操作。
Output

输出若干行,对于每个询问输出一行一个整数,即第k小的值。
Sample Input

4 5

1 5 6 12

2 5

1 1

1 2

1 3

1 4
Sample Output

1

1

5

7

HINT

Source

本OJ付费获取

sol:

考虑用splay来维护数的排名。感官上有些数修改之后的相对排名时不会变的。
考虑1到k的数不会被修改。k+1到2k的数修改之后会和前面的数排名交叉,2k+1到inf的数修改后相对排名不变。后面的数打个标记即可。中间的数至少会减少一半,暴力修改插入即可。

像我这样不回收内存的,记得把内存开大点。我调了2h。。

#include<cstdio>
#include<algorithm>
#include<string>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<iostream>
using namespace std;
typedef long long ll;
typedef double db;
int n,m;
inline int read()
{
    char c;
    int res,flag=0;
    while((c=getchar())>'9'||c<'0') if(c=='-')flag=1;
    res=c-'0';
    while((c=getchar())>='0'&&c<='9') res=(res<<3)+(res<<1)+c-'0';
    return flag?-res:res;
}
const int N=5e6+7;
int val[N],lc[N],rc[N],fa[N],siz[N],cnt[N];
int tag[N];
inline void add(int x,int y)
{
    if(x)
    {
        val[x]+=y;
        tag[x]+=y;
    }
}
inline void tag_down(int x)
{
    if(tag[x])
    {
        add(lc[x],tag[x]);
        add(rc[x],tag[x]);
        tag[x]=0;
    }
}
inline void updata(int x)
{
    siz[x]=siz[lc[x]]+siz[rc[x]]+cnt[x];
}
inline void rotate(int x)
{
    int y=fa[x],z=fa[y];
    int b=lc[y]==x?rc[x]:lc[x];
    if(b) fa[b]=y;
    fa[x]=z;fa[y]=x;
    if(z)
    {
        if(lc[z]==y) lc[z]=x;
        else rc[z]=x;
    }
    if(rc[y]==x) rc[y]=b,lc[x]=y;
    else lc[y]=b,rc[x]=y;
    updata(y); 
}
int sta[N],rt;
inline void splay(int x,int f)
{
    sta[sta[0]=1]=x;
    for(int y=x;fa[y];y=fa[y]) sta[++sta[0]]=fa[y];
    while(sta[0]) tag_down(sta[sta[0]--]);
    while(fa[x]!=f)
    {
        if(fa[fa[x]]!=f)
        {
            if((lc[fa[fa[x]]]==fa[x])==(lc[fa[x]]==x)) rotate(fa[x]);
            else rotate(x);
        }
        rotate(x);
    } 
    updata(x);
    if(!f) rt=x;
}
inline int kth(int x,int y)
{
    while(true)
    {
        tag_down(x); 
        if(siz[lc[x]]>=y) x=lc[x];
        else if(siz[lc[x]]+cnt[x]>=y) return val[x];
        else y-=siz[lc[x]]+cnt[x],x=rc[x];
    }
}
int tot;
inline int insert(int x,int y)
{
    int now=rt,now2;
    while(now)
    {
        now2=now;
        tag_down(now);
        if(val[now]<x) now=rc[now];
        else if(val[now]>x) now=lc[now];
        else
        {
            cnt[now]+=y;
            splay(now,0);
            return now;
        }
    }
    ++tot;
    if(val[now2]<x) rc[now2]=tot;
    else lc[now2]=tot;
    val[tot]=x;
    fa[tot]=now2;
    cnt[tot]=y;
    splay(tot,0);
    return tot;
}
int a[N];
inline void build(int &k,int l,int r)
{
    int mid=l+r>>1;
    k=mid;
    val[k]=a[k];
    if(l<=mid-1) build(lc[mid],l,mid-1);
    if(mid+1<=r) build(rc[mid],mid+1,r);
    if(lc[k]) fa[lc[k]]=k;
    if(rc[k]) fa[rc[k]]=k; 
    updata(k);
}
int b;
int q[N],qr;
inline void sc(int x)
{
    if(!x) return;
    tag_down(x);
    if(lc[x]) sc(lc[x]);
    if(rc[x]) sc(rc[x]);
    q[++qr]=x;
    val[x]-=b;
    lc[x]=rc[x]=fa[x]=0;
}
inline void del(int x)
{
    if(cnt[x]>1)
    {
        cnt[x]--;
        return;
    }
    splay(x,0);
    if(!lc[x])
    {
        rt=rc[x];
        fa[rc[x]]=0;
        return;
    }
    if(!rc[x])
    {
        rt=lc[x];
        fa[lc[x]]=0;
        return;
    }
    rt=rc[x];
    fa[rc[x]]=0;
    int y=rc[x];
    tag_down(y);
    while(lc[y])
    {
        y=lc[y];
        tag_down(y);
    }
    lc[y]=lc[x];
    fa[lc[x]]=y;
    splay(lc[x],0);
}
inline void debug(int x)
{
    if(!x) return;
    tag_down(x);
    if(lc[x]) debug(lc[x]);
    printf("%d %d\n",val[x],cnt[x]);
    if(rc[x]) debug(rc[x]);
}
int main()
{
//  freopen("kth.in","r",stdin);
//  freopen("kth.out","w",stdout);
    n=read();
    m=read();
    for(int i=1;i<=n;++i) a[i]=read();
    a[0]=-1;
    a[++n]=0;
    sort(a+1,a+1+n);
    int nn=n;
    n=0;
    for(int i=1;i<=nn;++i)
    {
        if(a[i]!=a[i-1]) a[++n]=a[i];
        cnt[n]++;
    }
    build(rt,1,n);
    tot=n;
    int a;
    for(int j=1;j<=m;++j)
    {
        a=read();
        b=read();
//      debug(rt);
//      printf("\n");
        if(a==1)
        printf("%d\n",kth(rt,b+1));
        else
        {
            int x=insert(b,1);
            int y=insert(b*2+1,1);
            splay(x,0);
            splay(y,x);
            qr=0;
            sc(lc[y]);
            lc[y]=0;
            for(int i=1;i<=qr;++i) insert(val[q[i]],cnt[q[i]]);
            del(x);
            del(y);
            x=insert(b*2,1);
            add(rc[x],-b);
            del(x);
        }
    }
}
阅读更多

没有更多推荐了,返回首页