splay
终于完成了!这个splay板子非常牛逼…本蒟蒻在做的时候犯了很多错误,调了三天。
数据结构题,当然要面向代码,现在将比较难懂的函数梳理一下。如果您只想看完整代码,请跳到最下面。
要维护的东西
a:原来的序列
b:插入操作中的序列
v:节点的值
siz:以该节点为根的子树大小
son:该节点的左右儿子
f:该节点的父亲
st:存放可以使用的数组下标的一个栈(即将删去了的节点原来的标记存入其中)
same:区间赋值的懒惰标记
rev:旋转的懒惰标记
sum:子树的权值和
mx:子树中的最大子序列和
ls:包含子树所代表的序列左端点的最大子序列(如果左端点是负值,则为0)
rs:包含子树所代表的序列右端点的最大子序列(如果右端点是负值,则为0)
这题不需要特别注意的基本操作都可以看本蒟蒻的另一篇博客:点我起飞
pushup
void up(int x) {
int l=son[x][0],r=son[x][1];
sum[x]=sum[l]+sum[r]+v[x],siz[x]=siz[l]+siz[r]+1;
ls[x]=max(ls[l],sum[l]+v[x]+ls[r]);
rs[x]=max(rs[r],sum[r]+v[x]+rs[l]);
mx[x]=max(max(mx[l],mx[r]),rs[l]+ls[r]+v[x]);
}
…好像看代码能懂吧?就不详细讲解了。
pushdown
别的地方都没有什么问题,就是本蒟蒻在打rev标记的时候出过两个错误,和大家分享一下:
错误1:打了rev标记后就直接跑路了,然后在处理到有rev标记的点的时候再将其左右子树反过来。这样会导致打完rev标记后,该节点祖先的ls和rs出错从而导致mx出错。
错误2:在交换完左右子树后,pushup了一下。这是不对的,因为交换完左右子树后只是部分翻转了,而事实上我们应该把这颗子树完全翻转,所以不能pushup而是要交换ls和rs
void pd(int x) {
int l=son[x][0],r=son[x][1],t=same[x];
if(same[x]!=inf) {
if(l) {//注意防止处理0
same[l]=v[l]=t,sum[l]=siz[l]*t;
if(t>=0) ls[l]=rs[l]=mx[l]=sum[l];
else ls[l]=rs[l]=0,mx[l]=t;
}
if(r) {
same[r]=v[r]=t,sum[r]=siz[r]*t;
if(t>=0) ls[r]=rs[r]=mx[r]=sum[r];
else ls[r]=rs[r]=0,mx[r]=t;
}
same[x]=inf;
}
if(rev[x]) {
if(l) rev[l]^=1,swap(son[l][0],son[l][1]),swap(ls[l],rs[l]);
if(r) rev[r]^=1,swap(son[r][0],son[r][1]),swap(ls[r],rs[r]);
rev[x]=0;
}
}
build和newjd
没错…就是这么暴力的函数名
int newjd() {//寻找可用的新节点编号
++trs;//splay的大小
if(top) return st[top--];
else return trs;
}
int build(int l,int r,int fnum,int fa,int *a) {
if(l>r) return 0;
int t=newjd(),mid=(l+r)/2;
if(l!=r) build(l,mid-1,0,t,a),build(mid+1,r,1,t,a);
v[t]=a[mid],son[fa][fnum]=t,same[t]=inf,rev[t]=0,f[t]=fa,up(t);
return t;
}
由于空间限制,我们需要一个st栈,储存所有删除操作中被清理出来的空间,然后优先使用这些空间。newjd就是一个用来寻找这样的空间的函数。
因而我们一定要注意编号和a数组编号的不同为build函数造成的影响。
do系列
即进行各种操作用的函数。
我们需要两个哨兵节点,一个放在最前面,一个放在最后面,防止把这棵树删干净。
然后将需要操作的序列左端点前面的那个节点splay到根,右端点后面的那个节点splay到根的右节点,这样根的右节点的左节点那棵子树就是我们要操作的序列,直接对其操作即可。
完整代码
#include<bits/stdc++.h>
using namespace std;
int read() {
int q=0,w=1;char ch=' ';
while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
if(ch=='-') w=-1,ch=getchar();
while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
return w*q;
}
const int N=500005,inf=100000000;
int a[N],b[N],v[N],siz[N],son[N][2],f[N];
int st[N*10],same[N],rev[N],sum[N],mx[N],ls[N],rs[N];
//st:放置可用节点的栈,same:set标记,rev:翻转标记,sum:和
//mx:最大子列,ls:左边最大,rs:右边最大
int n,m,top,rot,trs;
void up(int x) {
int l=son[x][0],r=son[x][1];
sum[x]=sum[l]+sum[r]+v[x],siz[x]=siz[l]+siz[r]+1;
ls[x]=max(ls[l],sum[l]+v[x]+ls[r]);
rs[x]=max(rs[r],sum[r]+v[x]+rs[l]);
mx[x]=max(max(mx[l],mx[r]),rs[l]+ls[r]+v[x]);
}
void pd(int x) {
int l=son[x][0],r=son[x][1],t=same[x];
if(same[x]!=inf) {
if(l) {
same[l]=v[l]=t,sum[l]=siz[l]*t;
if(t>=0) ls[l]=rs[l]=mx[l]=sum[l];
else ls[l]=rs[l]=0,mx[l]=t;
}
if(r) {
same[r]=v[r]=t,sum[r]=siz[r]*t;
if(t>=0) ls[r]=rs[r]=mx[r]=sum[r];
else ls[r]=rs[r]=0,mx[r]=t;
}
same[x]=inf;
}
if(rev[x]) {
if(l) rev[l]^=1,swap(son[l][0],son[l][1]),swap(ls[l],rs[l]);
if(r) rev[r]^=1,swap(son[r][0],son[r][1]),swap(ls[r],rs[r]);
rev[x]=0;
}
}
int is(int x) {return son[f[x]][1]==x;}
void spin(int x,int &mb) {
int fa=f[x],g=f[fa],t=is(x);
if(mb==fa) mb=x;
else son[g][is(fa)]=x;
f[x]=g,f[fa]=x,f[son[x][t^1]]=fa;
son[fa][t]=son[x][t^1],son[x][t^1]=fa;
up(fa),up(x);
}
void splay(int x,int &mb) {
while(x!=mb) {
int fa=f[x],g=f[fa];
if(fa!=mb) {
if(is(fa)^is(x)) spin(x,mb);
else spin(fa,mb);
}
spin(x,mb);
}
}
int find(int x,int kth) {
if(same[x]!=inf||rev[x]) pd(x);
if(kth==siz[son[x][0]]+1) return x;
else if(kth<=siz[son[x][0]]) return find(son[x][0],kth);
else return find(son[x][1],kth-siz[son[x][0]]-1);
}
int newjd() {//寻找可用的新节点编号
++trs;
if(top) return st[top--];
else return trs;
}
int build(int l,int r,int fnum,int fa,int *a) {
if(l>r) return 0;
int t=newjd(),mid=(l+r)/2;
if(l!=r) build(l,mid-1,0,t,a),build(mid+1,r,1,t,a);
v[t]=a[mid],son[fa][fnum]=t,same[t]=inf,rev[t]=0,f[t]=fa,up(t);
return t;
}
void do1() {
int k,num,x,y;
k=read(),num=read();
for(int i=1;i<=num;++i) b[i]=read();
x=find(rot,k+1),y=find(rot,k+2);
splay(x,rot),splay(y,son[rot][1]);
son[y][0]=build(1,num,0,y,b);
up(y),up(rot);
}
void del(int x) {
if(!x) return;
v[x]=siz[x]=f[x]=rev[x]=sum[x]=0,same[x]=inf,mx[x]=-inf;
ls[x]=rs[x]=-inf,st[++top]=x;
del(son[x][0]),del(son[x][1]);son[x][0]=son[x][1]=0;
}
void do2() {
int l=read(),r=read()+l-1,x,y;
x=find(rot,l),y=find(rot,r+2);
splay(x,rot),splay(y,son[rot][1]);
trs-=siz[son[y][0]],del(son[y][0]),son[y][0]=0;
up(y),up(rot);
}
void do3() {
int l=read(),r=read()+l-1,x,y,kl,tmp=read();
x=find(rot,l),y=find(rot,r+2);
splay(x,rot),splay(y,son[rot][1]),kl=son[y][0];
v[kl]=same[kl]=tmp,sum[kl]=siz[kl]*tmp;
if(tmp>=0) mx[kl]=ls[kl]=rs[kl]=sum[kl];
else mx[kl]=tmp,ls[kl]=rs[kl]=0;
up(y),up(rot);
}
void do4() {
int l=read(),r=read()+l-1,x,y,kl;
x=find(rot,l),y=find(rot,r+2);
splay(x,rot),splay(y,son[rot][1]),kl=son[y][0];
rev[kl]^=1,swap(son[kl][0],son[kl][1]),swap(ls[kl],rs[kl]);
}
int do5() {
int l=read(),r=read()+l-1,x,y;
x=find(rot,l),y=find(rot,r+2);
splay(x,rot),splay(y,son[rot][1]);
return sum[son[y][0]];
}
int do6() {
int x=find(rot,1),y=find(rot,trs);
splay(x,rot),splay(y,son[rot][1]);
return mx[son[y][0]];
}
int main()
{
char ch[20];
n=read(),m=read();
mx[0]=-inf,a[1]=-inf,a[n+2]=inf;
for(int i=2;i<=n+1;++i) a[i]=read();
rot=build(1,n+2,0,0,a);
for(int i=1;i<=m;++i) {
scanf("%s",ch);
if(ch[0]=='I') do1();
else if(ch[0]=='D') do2();
else if(ch[0]=='M'&&ch[2]=='K') do3();
else if(ch[0]=='R') do4();
else if(ch[0]=='G') printf("%d\n",do5());
else if(ch[0]=='M'&&ch[2]=='X') printf("%d\n",do6());
}
return 0;
}
无旋treap
和上面那货差不多,按照无旋treap的基本操作来就可以了。
splay写了3天,无旋treap只写了两个小时,看来,treap这东西虽然慢一些,但是还是吼啊!
#include<bits/stdc++.h>
using namespace std;
int read() {
int q=0,w=1;char ch=' ';
while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
if(ch=='-') w=-1,ch=getchar();
while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
return q*w;
}
#define mkp make_pair
#define pr pair<int,int>
const int N=500005,inf=0x3f3f3f3f;
int son[N][2],sz[N],pos[N],v[N],a[N],s[N];
int ls[N],rs[N],mx[N],rev[N],sum[N],laz[N];//s!
int n,m,top,rt,sss;
void up(int x) {
int l=son[x][0],r=son[x][1];
sz[x]=sz[l]+sz[r]+1,sum[x]=sum[l]+sum[r]+v[x];
ls[x]=max(ls[l],sum[l]+v[x]+ls[r]);
rs[x]=max(rs[r],sum[r]+v[x]+rs[l]);
mx[x]=max(rs[l]+ls[r]+v[x],max(mx[l],mx[r]));
}
void pd(int x) {
int l=son[x][0],r=son[x][1],t=laz[x];
if(laz[x]!=inf) {
if(l) {
laz[l]=v[l]=t,sum[l]=sz[l]*t;
if(t>=0) ls[l]=rs[l]=mx[l]=sum[l];
else ls[l]=rs[l]=0,mx[l]=t;
}
if(r) {
laz[r]=v[r]=t,sum[r]=sz[r]*t;
if(t>=0) ls[r]=rs[r]=mx[r]=sum[r];
else ls[r]=rs[r]=0,mx[r]=t;
}
laz[x]=inf;
}
if(rev[x]) {
if(l) rev[l]^=1,swap(son[l][0],son[l][1]),swap(ls[l],rs[l]);
if(r) rev[r]^=1,swap(son[r][0],son[r][1]),swap(ls[r],rs[r]);
rev[x]=0;
}
}
pr split(int x,int num) {
if(!x) return mkp(0,0);
pd(x);int l=son[x][0],r=son[x][1];
if(sz[son[x][0]]==num) {son[x][0]=0,up(x);return mkp(l,x);}
if(sz[son[x][0]]+1==num) {son[x][1]=0,up(x);return mkp(x,r);}
if(sz[son[x][0]]>num) {
pr tmp=split(son[x][0],num);
son[x][0]=tmp.second,up(x);
return mkp(tmp.first,x);
}
else {
pr tmp=split(son[x][1],num-sz[son[x][0]]-1);
son[x][1]=tmp.first,up(x);
return mkp(x,tmp.second);
}
}
int merge(int a,int b) {
if(!a) return b;
if(!b) return a;
pd(a),pd(b);
if(pos[a]<pos[b]) {son[a][1]=merge(son[a][1],b),up(a);return a;}
else {son[b][0]=merge(a,son[b][0]),up(b);return b;}
}
int newjd(int num) {
int x;
if(!top) x=++sss;
else x=s[top--];
sz[x]=1,pos[x]=rand(),v[x]=sum[x]=mx[x]=num;
son[x][0]=son[x][1]=0,laz[x]=inf,ls[x]=rs[x]=max(0,num);
return x;
}
int build(int l,int r) {
if(l==r) {int x=newjd(a[l]);return x;}
int mid=(l+r)>>1;
merge(build(l,mid),build(mid+1,r));
}
void del(int x) {
if(son[x][0]) del(son[x][0]);
if(son[x][1]) del(son[x][1]);
if(top<N-5) s[++top]=x;
}
int main()
{
char ch[10];int x,tt,num;
mx[0]=-inf;
n=read(),m=read(),srand(m);
for(int i=1;i<=n;++i) a[i]=read();
rt=build(1,n);
while(m--) {
scanf("%s",ch);
if(ch[0]=='I') {
x=read(),tt=read();
pr tmp=split(rt,x);
for(int i=1;i<=tt;++i) a[i]=read();
rt=merge(tmp.first,merge(build(1,tt),tmp.second));
}
else if(ch[0]=='D') {
x=read(),tt=read();
pr t1=split(rt,x+tt-1),t2=split(t1.first,x-1);
rt=merge(t2.first,t1.second),del(t2.second);
}
else if(ch[0]=='M'&&ch[2]=='K') {
x=read(),tt=read(),num=read();
pr t1=split(rt,x+tt-1),t2=split(t1.first,x-1);
int kl=t2.second;
laz[kl]=v[kl]=num,sum[kl]=num*sz[kl];
if(num>=0) ls[kl]=rs[kl]=mx[kl]=sum[kl];
else ls[kl]=rs[kl]=0,mx[kl]=num;
rt=merge(t2.first,merge(t2.second,t1.second));
}
else if(ch[0]=='R') {
x=read(),tt=read();
pr t1=split(rt,x+tt-1),t2=split(t1.first,x-1);
int kl=t2.second;
rev[kl]^=1,swap(son[kl][0],son[kl][1]),swap(ls[kl],rs[kl]);
rt=merge(t2.first,merge(t2.second,t1.second));
}
else if(ch[0]=='G') {
x=read(),tt=read();
pr t1=split(rt,x+tt-1),t2=split(t1.first,x-1);
printf("%d\n",sum[t2.second]);
rt=merge(t2.first,merge(t2.second,t1.second));
}
else if(ch[0]=='M') printf("%d\n",mx[rt]);
}
return 0;
}