用splay来维护数列的一道很好的模版练习题
有些要注意的地方:
- 不用指针来实现splay的话要注意内存回收,不然就会无限RE+TLE
- 凡是有对序列进行修改的操作都要up一下
- reverse的时候要注意交换左起最大连续和跟右起最大连续和
代码:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<string>
#include<iomanip>
#include<vector>
#include<set>
#include<map>
#include<queue>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
#define rep(i,k,n) for(int i=(k);i<=(n);i++)
#define rep0(i,n) for(int i=0;i<(n);i++)
#define red(i,k,n) for(int i=(k);i>=(n);i--)
#define sqr(x) ((x)*(x))
#define clr(x,y) memset((x),(y),sizeof(x))
#define pb push_back
#define mod 1000000007
#define L ch[x][0]
#define R ch[x][1]
#define KT (ch[ch[rt][1]][0])
const int inf=-2000;
const int maxn=510000;
int numa[maxn],numb[maxn];
struct SplayTree
{
int sz[maxn];
int ch[maxn][2];
int pre[maxn];
int rt,top;
int pool[maxn],ptop;
int flip[maxn];
int same[maxn];
int val[maxn];
int sum[maxn];
int maxs[maxn],maxl[maxn],maxr[maxn];
inline void down(int x)
{
if(flip[x])
{
flip[L]^=1;
flip[R]^=1;
swap(L,R);
swap(maxl[L],maxr[L]);
swap(maxl[R],maxr[R]);
flip[x]=0;
}
if(same[x]!=inf)
{
rep0(i,2)if(ch[x][i])
{
same[ch[x][i]]=same[x];
val[ch[x][i]]=same[x];
sum[ch[x][i]]=same[x]*sz[ch[x][i]];
if(same[x]<=0)maxs[ch[x][i]]=maxl[ch[x][i]]=maxr[ch[x][i]]=same[x];
else maxs[ch[x][i]]=maxl[ch[x][i]]=maxr[ch[x][i]]=sum[ch[x][i]];
}
same[x]=inf;
}
}
inline void up(int x)
{
sz[x]=1+sz[L]+sz[R];
sum[x]=sum[L]+val[x]+sum[R];
maxl[x]=max(max(maxl[L],sum[L]+val[x]),sum[L]+val[x]+maxl[R]);
maxr[x]=max(max(maxr[R],sum[R]+val[x]),sum[R]+val[x]+maxr[L]);
maxs[x]=max(max(maxs[L],maxs[R]),max(maxr[L],maxl[R])+val[x]);
maxs[x]=max(max(maxs[x],val[x]),maxr[L]+val[x]+maxl[R]);
}
inline void Rotate(int x,int f)
{
int y=pre[x];
down(y);
down(x);
ch[y][!f]=ch[x][f];
pre[ch[x][f]]=y;
pre[x]=pre[y];
if(pre[x])ch[pre[y]][ch[pre[y]][1]==y]=x;
ch[x][f]=y;
pre[y]=x;
up(y);
}
void Splay(int x,int goal)
{
down(x);
while(pre[x]!=goal)
{
down(pre[pre[x]]);down(pre[x]);down(x);
if(pre[pre[x]]==goal)Rotate(x,ch[pre[x]][0]==x);
else
{
int y=pre[x],z=pre[y];
int f=(ch[z][0]==y);
if(ch[y][f]==x)Rotate(x,!f),Rotate(x,f);
else Rotate(y,f),Rotate(x,f);
}
}
up(x);
if(goal==0)rt=x;
}
inline void RTO(int k,int goal)
{
int x=rt;
down(x);
while(sz[L]+1!=k)
{
if(k<sz[L]+1)x=L;
else
{
k-=(sz[L]+1);
x=R;
}
down(x);
}
Splay(x,goal);
}
void Newnode(int &x,int c,int f)
{
if(ptop)
{
x=pool[ptop--];
}
else
{
x=++top;
}
flip[x]=0;same[x]=inf;
L=R=0;pre[x]=f;
sz[x]=1;
val[x]=sum[x]=maxs[x]=maxl[x]=maxr[x]=c;
}
void build(int &x,int l,int r,int f,int num[])
{
if(l>r)return ;
int m=l+r>>1;
Newnode(x,num[m],f);
build(L,l,m-1,x,num);
build(R,m+1,r,x,num);
pre[x]=f;
up(x);
}
void init(int n)
{
ch[0][0]=ch[0][1]=pre[0]=sz[0]=0;
rt=top=0;flip[0]=val[0]=0;
ptop=0;
same[0]=inf;
sum[0]=0;
maxs[0]=maxl[0]=maxr[0]=inf;
Newnode(rt,inf,0);
Newnode(ch[rt][1],inf,rt);
sz[rt]=2;
build(KT,1,n,ch[rt][1],numa);
up(ch[rt][1]);up(rt);
}
void SAME(int pos,int len,int c)
{
if(len<=0)return;
int a=pos,b=pos+len-1;
RTO(a,0);
RTO(b+2,rt);
same[KT]=val[KT]=c;
sum[KT]=c*sz[KT];
if(c<=0)maxs[KT]=maxl[KT]=maxr[KT]=c;
else maxs[KT]=maxl[KT]=maxr[KT]=sum[KT];
up(ch[rt][1]);up(rt);
}
void REVERSE(int pos,int len)
{
if(len<=0)return;
int a=pos,b=pos+len-1;
RTO(a,0);
RTO(b+2,rt);
flip[KT]^=1;
swap(maxl[KT],maxr[KT]);
up(ch[rt][1]);up(rt);
}
void INSERT(int pos,int len,int c[])
{
if(len<=0)return;
RTO(pos+1,0);
RTO(pos+2,rt);
build(KT,1,len,ch[rt][1],c);
up(ch[rt][1]);up(rt);
}
void collect(int x)
{
if(x==0)return;
pool[++ptop]=x;
collect(L);collect(R);
}
void DELETE(int pos,int len)
{
if(len<=0)return;
int a=pos,b=pos+len-1;
RTO(a,0);
RTO(b+2,rt);
collect(KT);
KT=0;
up(ch[rt][1]);up(rt);
}
int GETSUM(int pos,int len)
{
if(len<=0)return 0;
int a=pos,b=pos+len-1;
RTO(a,0);
RTO(b+2,rt);
return sum[KT];
}
int MAXSUM()
{
if(sz[rt]==2)return 0;
return maxs[rt];
}
// void debug(int x)
// {
// down(x);
// if(L)debug(L);
// printf("%d ",val[x]);
// if(R)debug(R);
// }
}spt;
int main()
{
int n,m,tmp,pos,len;
char str[20];
scanf("%d%d",&n,&m);
rep(i,1,n)scanf("%d",&numa[i]);
spt.init(n);
rep(i,1,m)
{
//spt.debug(spt.rt);
scanf("%s",str);
if(str[0]=='I')
{
scanf("%d%d",&pos,&len);
rep(j,1,len)scanf("%d",&numb[j]);
spt.INSERT(pos,len,numb);
}
else if(str[0]=='D')
{
scanf("%d%d",&pos,&len);
spt.DELETE(pos,len);
}
else if(str[0]=='R')
{
scanf("%d%d",&pos,&len);
spt.REVERSE(pos,len);
}
else if(str[0]=='G')
{
scanf("%d%d",&pos,&len);
printf("%d\n",spt.GETSUM(pos,len));
}
else
{
if(str[2]=='X')
{
printf("%d\n",spt.MAXSUM());
}
else if(str[2]=='K')
{
scanf("%d%d%d",&pos,&len,&tmp);
spt.SAME(pos,len,tmp);
}
}
}
return 0;
}