第十二届蓝桥杯国赛 冰山 题解 splay树
题目
暂时搞不到原题目先空在这里。
思路
个人觉得模拟会超时超内存(会有很多个1,甚至多于998244353),所以用平衡树来操作,这里就用splay数,时间复杂度是O((n+3m)*log(n+3m)),题目总共三种操作:
1、每天一开始为所有冰山增加体积,(设每座冰山的体积为vi)对于每一座冰山,若vi>k则添加vi-k座体积为1的冰山,最后让vi=k;
2、每天一开始为所有冰山减小体积,所有体积小于等于0的冰山全部消失;
3、每天快结束的时候添加一座冰山。
首先我们建立一颗splay树,就是最模板的那种,只需要add操作,每次add的那个结点要splay到根,结点上的权值就是冰山的体积,另外每个结点还需要记一个cnt,用来标记相同体积冰山的个数,例如有这么多的冰山:{1,1,1,2,3},则只需要三个结点就行了,每个结点上的权值v1=1,v2=2,v3=3,每个结点上的cnt:cnt1=3,cnt2=1,cnt3=1。
我们倒着讲。
显然第三种操作非常轻松,只要add就行了。
而第二种操作也挺简单的,假设题目要求所有的冰山体积全部减少x,那么只需要把体积为x的结点旋转到根,此时根结点的左子树上所有的结点的权值必然小于x,题目要求这些冰山要全部消失,所以我们只需要删除根结点的左子结点就行了,并且把根结点的cnt改为0,这样就能满足题目的要求,然后我们还需要在根节点打个-x的懒标记,这样就能保证答案的正确。但此时有一个问题,要是不存在权值为x的结点怎么办?对于这个问题,我一点也不想思考,直接add一个权值为x的结点然后再进行上述操作就行了,也就多开个长度为m的空间而已。
最后就是第一种操作,假设题目要求所有的冰山体积全部增加x,根据题意,所有体积大于等于k的冰山全都会变成体积为k的冰山,而多出来的体积全部分裂成体积为1的小冰山。我们先不管那些体积为1的冰山,先处理前半部分。我们先把权值为k-x的结点旋转到根,此时根节的右子树上所有结点的权值都一定大于k-x,这句话写作vi>k-x,x移到式子左边就是vi+x>k,那么就说明根结点的右子树上所有的结点在执行“所有的冰山体积全部增加x”的操作之后,他们的权值全部都会大于k,那么我们只需要把根结点右子树上冰山的个数(一棵树的冰山个数,等于这棵树上所有结点的cnt总和)加到根节点的cnt上面就行了,这样就完成了前半部分的操作。后半部分也很简单,先计算出体积多出了多少,假设根结点的右子树上冰山的个数为sizer,根结点的右子树上所有冰山的体积之和为sumr,那么多出来的体积就是Δv=sumr+sizer*x-sizer*k,Δv的大小就是需要添加的冰山个数,不过不能现在添加,还有操作没有做完,标记了Δv后,把根结点的右儿子删除,根结点打上+x的懒标记,最后再添加Δv个体积为1的冰山。我们不可能一个一个添加,因为Δv有可能会很大,所以还需要稍微改造一下splay的add操作,让它可以add指定数量的冰山,如果树中本来就有权值为1的结点,那么我们希望新添加的体积为1的冰山要跟这个结点合并,保证同一个权值至多只有一个结点。
所有的操作都完成了,如果代码没有问题,那么只需要每次输出根结点维护的体积总和就可以了。
复杂度分析
第一种操作是最费空间和时间的,所以最差情况就是所有操作都是一操作。首先是空间,一开始就有n座冰山,空间为n,对于m天,每天都会加入一座冰山,此时空间为n+m,若每天所有冰山的体积都增加,我们在执行“把权值为k-x的结点旋转到根”这步操作的时候,实际上是直接添加体积为k-x,cnt为0的冰山,这里至多会有一个新的结点,此时空间为n+2m,再之后要添加体积为1的冰山,又至多会有一个新的节点,所以最后的空间至少要开n+3m,应该不会超。splay的时间复杂度我不会算,感觉大概差不多好像就是O((n+3m)*log(n+3m))。
代码
import java.io.*;
import java.util.Scanner;
import java.util.TreeMap;
public class Main
{
long p=998244353,k;
int n,m,len;
long []val,sum,cnt,siz,laz;
int [][]ch;
int []f;
int root=0;
void del(int a)
{
ch[a][0]=0;
ch[a][1]=0;
int fa=f[a];
if(fa!=0){
ch[fa][ch[fa][0]==a?0:1]=0;
}
f[a]=0;
}
void update(int a){
siz[a]=(siz[ch[a][0]]+siz[ch[a][1]]+cnt[a])%p;
sum[a]=(sum[ch[a][0]]+sum[ch[a][1]]+val[a]*cnt[a]%p)%p;
}
void connect(int a,int fa,int s){
if(fa!=0)ch[fa][s]=a;
if(a!=0)f[a]=fa;
}
void rotate(int a){
int fa=f[a],ffa=f[fa];
int s=ch[fa][0]==a?0:1;
connect(ch[a][s^1],fa,s);
connect(a,ffa,ch[ffa][0]==fa?0:1);
connect(fa,a,s^1);
update(fa);
update(a);
}
void splay(int a,int to){
for(int fa=f[a];fa!=to;rotate(a),fa=f[a])
{
int ffa=f[fa];
if(ffa!=to){
if((ch[ffa][0]==fa)==(ch[fa][0]==a))rotate(fa);
else rotate(a);
}
}
if(to==0)
root=a;
}
void pushdown(int a)
{
if(laz[a]!=0)
{
for(int i=0;i<2;i++)
if(ch[a][i]!=0)
{
val[ch[a][i]]+=laz[a];
sum[ch[a][i]]+=laz[a]*1l*siz[ch[a][i]]%p;
sum[ch[a][i]]%=p;
laz[ch[a][i]]+=laz[a];
}
laz[a]=0;
}
}
void add(long x,int c,int to)
{
if(root==0)
{
root=len;
val[len]=x;
cnt[len]+=c;
update(len);
root=len;
len++;
return;
}
int a=root;
while(true){
pushdown(a);
if(val[a]==x){
splay(a,0);
cnt[a]+=c;
update(a);
root=a;
return;
}
if(val[a]>x){
if(ch[a][0]!=0)a=ch[a][0];
else{
val[len]=x;
cnt[len]+=c;
connect(len,a,0);
splay(len,to);
len++;
return;
}
}else{
if(ch[a][1]!=0)a=ch[a][1];
else{
val[len]=x;
cnt[len]+=c;
connect(len,a,1);
splay(len,to);
len++;
return;
}
}
}
}
void show(){
System.out.println("root "+root );
System.out.println("i\tl\tr\tf\tval\tcnt\tsiz\tsum\tlaz\t");
for(int i=1;i<len;i++)
{
System.out.printf("%d\t%d\t%d\t%d\t%d\t%d\t%d\t%d\t%d\n",i,ch[i][0],ch[i][1],f[i],val[i],cnt[i],siz[i],sum[i],laz[i]);
}
System.out.println("---------------------------------------------");
}
void run()throws IOException{
n=in();
m=in();
k=in();
val=new long[n+(m*3)];
sum=new long[n+(m*3)];
cnt=new long[n+(m*3)];
ch=new int[n+(m*3)][2];
f=new int[n+(m*3)];
siz=new long[n+(m*3)];
laz=new long[n+(m*3)];
len=1;
for(int i=0;i<n;i++)
{
add(in(),1,0);
}
for(int i=0;i<m;i++)
{
int x=in();
int y=in();
if(x>0){
add(k-x,0,0);
int rs=ch[root][1];
int tmp=0;
if(rs!=0){
cnt[root]+=siz[rs];
tmp+=sum[rs]+x*siz[rs]-k*siz[rs];
tmp%=p;
tmp+=p;
tmp%=p;
del(rs);
update(root);
}
sum[root]+=x*siz[root]%p;
val[root]+=x;
laz[root]+=x;
add(1,tmp,0);
}else if(x<0)
{
add(-x,0,0);
int ls=ch[root][0];
if(ls!=0)
{
del(ls);
}
cnt[root]=0;
update(root);
sum[root]+=x*siz[root]%p;
val[root]+=x;
laz[root]+=x;
}
add(y,1,0);
// show();
long ans=sum[root];
ans+=p;
ans%=p;
out.println(ans);
}
out.flush();
}
public static void main(String[]args)throws IOException
{
new Main().run();
}
StreamTokenizer in=new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
PrintWriter out=new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));
String ins() throws IOException{
in.nextToken();
return in.sval;
}
int in() throws IOException{
in.nextToken();
return (int)in.nval;
}
}
要是wa了就尴尬了。就算代码错了思路应该是对的。