[bzoj4399]魔法少女LJJ
我为什么要作死写splay......
前面的操作用平衡树(或者写线段树合并)就行了,主要是注意求积的操作是一定会算爆的所以我们可以考虑用其他的方式维护,比如直接求对数函数。剩下的就是码了。
写的奇丑不要在意
- 代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+1;
const int MX=1e9;
int root[N],prt[N],sum[N];
struct Splay{
int sz,sum,ch[2],fa;
long double val,myval;
int numb;
}t[N];
int null=0,cnt,tot;
int id[N];
void newnode(int _val){
t[++cnt].sum=_val;
t[cnt].sz=t[cnt].numb=1;
t[cnt].myval=t[cnt].val=log(_val);
++tot;
prt[tot]=tot;
root[tot]=cnt;
}
inline int find(int x){
return prt[x]==x?prt[x]:prt[x]=find(prt[x]);
}
bool son(int x){return t[t[x].fa].ch[1]==x;}
bool isroot(int x){return t[t[x].fa].ch[0]!=x&&t[t[x].fa].ch[1]!=x;}
inline void pushup(int x){
t[x].sz=t[x].numb;
t[x].val=t[x].myval;
if(t[x].ch[0])t[x].sz+=t[t[x].ch[0]].sz,t[x].val+=t[t[x].ch[0]].val;
if(t[x].ch[1])t[x].sz+=t[t[x].ch[1]].sz,t[x].val+=t[t[x].ch[1]].val;
}
inline void rotate(int x){
int f=t[x].fa;int gf=t[f].fa;
bool a=son(x),b=son(x)^1;
if(!isroot(f))t[gf].ch[son(f)]=x;
t[x].fa=gf;
t[f].ch[a]=t[x].ch[b];t[t[x].ch[b]].fa=f;
t[x].ch[b]=f;t[f].fa=x;
pushup(f);pushup(x);
//"pushup
}
inline void splay(int x){
while(!isroot(x)){
int f=t[x].fa;
if(!isroot(f)){
if(son(x)^son(f))rotate(x);
else rotate(f);
}
rotate(x);
}
}
inline void ins(int rt,int id){
if(t[rt].sum==t[id].sum){
t[rt].numb+=t[id].numb;
t[rt].myval=t[rt].numb*log(t[rt].sum);
}
if(t[rt].sum<t[id].sum){
if(t[rt].ch[1])ins(t[rt].ch[1],id);
else t[rt].ch[1]=id,t[id].fa=rt;
}
if(t[rt].sum>t[id].sum){
if(t[rt].ch[0])ins(t[rt].ch[0],id);
else t[rt].ch[0]=id,t[id].fa=rt;
}
pushup(rt);
}
inline void merge(int rt1,int rt2){
if(t[rt1].ch[0])merge(t[rt1].ch[0],rt2);
if(t[rt1].ch[1])merge(t[rt1].ch[1],rt2);
t[rt1].fa=t[rt1].ch[0]=t[rt1].ch[1]=null;
pushup(rt1);
ins(rt2,rt1);
}
int findmx(int u,int sum){//找第一个大于等于
if(!u)return null;
if(t[u].sum>=sum){
int ret=findmx(t[u].ch[0],sum);
if(ret)return ret;
else return u;
}else {
return findmx(t[u].ch[1],sum);
}
}
int findmn(int u,int sum){//找第一个小于等于
if(!u)return null;
if(t[u].sum<=sum){
int ret=findmn(t[u].ch[1],sum);
if(ret)return ret;
else return u;
}else {
return findmn(t[u].ch[0],sum);
}
}
inline int Kth(int u,int rank){
int pre=t[t[u].ch[0]].sz,nxt=t[t[u].ch[0]].sz+t[u].numb;
if(rank<=pre)return Kth(t[u].ch[0],rank);
if(rank>pre&&rank<=nxt)return t[u].sum;
if(rank>nxt)return Kth(t[u].ch[1],rank-nxt);
return 0;
}
int T;
int main()
{
//freopen("3.in","r",stdin);
scanf("%d",&T);
while(T--){
int opt,x,y;
scanf("%d",&opt);
if(opt==1){
scanf("%d",&x);
newnode(x);
}
if(opt==2){
scanf("%d%d",&x,&y);
int X=find(x),Y=find(y);
if(X==Y)continue;
if(t[root[X]].sz<t[root[Y]].sz){
prt[X]=Y;
merge(root[X],root[Y]);
}else{
prt[Y]=X;
merge(root[Y],root[X]);
}
}
if(opt==4){
scanf("%d%d",&x,&y);
int X=find(x);
int rt=root[X];
int tmp=findmx(rt,y);
if(!tmp)continue;
root[X]=tmp;
splay(tmp);
t[tmp].sum=y;
t[tmp].numb+=t[t[tmp].ch[1]].sz;
t[tmp].ch[1]=null;
t[tmp].myval=(long double)t[tmp].numb*log(t[tmp].sum);
pushup(tmp);
}
if(opt==3){
scanf("%d%d",&x,&y);
int X=find(x);
int rt=root[X];
int tmp=findmn(rt,y);
if(!tmp)continue;
root[X]=tmp;
splay(tmp);
t[tmp].sum=y;
t[tmp].numb+=t[t[tmp].ch[0]].sz;
t[tmp].ch[0]=null;
t[tmp].myval=(long double)t[tmp].numb*log(t[tmp].sum);
pushup(tmp);
}
if(opt==5){
scanf("%d%d",&x,&y);
int RT=root[find(x)];
printf("%d\n",Kth(RT,y));
}
if(opt==6){
scanf("%d%d",&x,&y);
int rtx=root[find(x)],rty=root[find(y)];
printf("%d\n",t[rtx].val>t[rty].val);
}
if(opt==7){
scanf("%d",&x);
int rt=root[find(x)];
printf("%d\n",t[rt].sz);
}
}
}