[BZOJ3224]普通平衡树

模板题嘛。。

splay:

#include<iostream>
#include<cstdio>
#include<memory.h>
#define maxn 500005
using namespace std;
int i,n,opt,x,nd=0,ans,root=0,c[maxn][2],size[maxn],pre[maxn],same[maxn],num[maxn];
void update(int x){size[x]=size[c[x][0]]+size[c[x][1]]+same[x];}
void newnode(int &x,int fa,int k)
{
	x=++nd;pre[x]=fa;
	size[x]=same[x]=1;
	c[x][0]=c[x][1]=0;
	num[x]=k;
}
void rot(int x,int kind)
{
	int y=pre[x],z=pre[y];
	c[y][!kind]=c[x][kind];pre[c[x][kind]]=y;
	c[x][kind]=y;pre[y]=x;
	pre[x]=z;if (z) c[z][c[z][1]==y]=x;
	update(y);update(x);
}
void splay(int x,int goal)
{
	int kind,y,z;
	while (pre[x]!=goal)
	{
		y=pre[x];
		if (pre[y]==goal) rot(x,c[y][0]==x);
		else
		{
			z=pre[y];kind=c[z][0]==y;
			if (c[y][kind]==x) rot(x,!kind);else rot(y,kind);
			rot(x,kind);
		}
	}
	if (goal==0) root=x;
}
void ins(int k)
{
	int get=root;
	if (!root) newnode(root,0,k);
	if (num[get]==k){same[get]++;size[get]++;return;}
	while (c[get][num[get]<k])
	{
		get=c[get][num[get]<k];
		if (num[get]==k){same[get]++;size[get]++;splay(get,0);return;}
	}
	newnode(c[get][num[get]<k],get,k);
	splay(c[get][num[get]<k],0);
}
void find(int x,int k)
{
	if (num[x]==k) {splay(x,0);return;}else find(c[x][num[x]<k],k);
}
int join(int s1,int s2)
{
	while (c[s1][1]) s1=c[s1][1];
	splay(s1,pre[s2]);
	c[s1][1]=s2;pre[s2]=s1;
	update(s1);return s1;
}
void del(int k)
{
	find(root,k);
	if (c[root][0]*c[root][1]==0) {root=c[root][0]+c[root][1];return;}
	root=join(c[root][0],c[root][1]);
}
int rank(int x,int k)
{
	if (num[x]==k) return size[c[x][0]]+1;
	if (num[x]>k) return rank(c[x][0],k);else return rank(c[x][1],k)+size[c[x][0]]+same[x];
}
int findkth(int x,int k)
{
	if (size[c[x][0]]>=k) return findkth(c[x][0],k);
	if (size[c[x][0]]+same[x]>=k) return num[x];
	return findkth(c[x][1],k-size[c[x][0]]-same[x]);
}
void getpre(int x,int k)
{
	if (!x) return;
	if (num[x]<k){ans=x;getpre(c[x][1],k);}else getpre(c[x][0],k);
}
void getpos(int x,int k)
{
	if (!x) return;
	if (num[x]>k){ans=x;getpre(c[x][0],k);}else getpre(c[x][1],k);
}
int main()
{
	scanf("%d",&n);
	for (i=1;i<=n;i++)
	{
		scanf("%d%d",&opt,&x);
		switch(opt)
		{
			case 1:ins(x);break;
			case 2:del(x);break;
			case 3:printf("%d\n",rank(root,x));break;
			case 4:printf("%d\n",findkth(root,x));break;
			case 5:getpre(root,x);printf("%d\n",num[ans]);break;
			case 6:getpos(root,x);printf("%d\n",num[ans]);break;
		}
	}
}

treap:

#include<iostream>
#include<cstdio>
#include<memory.h>
#include<cstdlib>
#include<time.h>
#define maxn 500005
using namespace std;
int i,n,opt,x,nd=0,ans,root=0,c[maxn][2],size[maxn],same[maxn],fix[maxn],num[maxn];
void update(int x){size[x]=size[c[x][0]]+size[c[x][1]]+same[x];}
void newnode(int &x,int k)
{
	x=++nd;c[x][0]=c[x][1]=0;
	size[x]=same[x]=1;
	fix[x]=rand();num[x]=k;
}
void rot(int &x,int kind)
{
	int t=c[x][!kind];
	c[x][!kind]=c[t][kind];c[t][kind]=x;
	update(x);update(t);x=t;
}
void ins(int &x,int k)
{
	if (!x) {newnode(x,k);return;}
	size[x]++;
	if (num[x]==k) {same[x]++;return;}
	ins(c[x][num[x]<k],k);
	if (fix[c[x][num[x]<k]]<fix[x]) rot(x,!(num[x]<k));
}
void del(int &x,int k)
{
	if (num[x]==k)
	{
		if (same[x]>1) {size[x]--,same[x]--;return;}
		if (c[x][0]*c[x][1]==0) {x=c[x][0]+c[x][1];return;}
		rot(x,fix[c[x][0]]<fix[c[x][1]]);del(x,k);
	}
	else size[x]--,del(c[x][num[x]<k],k);
}
int rank(int x,int k)
{
	if (num[x]==k) return size[c[x][0]]+1;
	if (num[x]>k) return rank(c[x][0],k);else return rank(c[x][1],k)+size[c[x][0]]+same[x];
}
int findkth(int x,int k)
{
	if (size[c[x][0]]>=k) return findkth(c[x][0],k);
	if (size[c[x][0]]+same[x]>=k) return num[x];
	return findkth(c[x][1],k-size[c[x][0]]-same[x]);
}
void getpre(int x,int k)
{
	if (!x) return;
	if (num[x]<k) {ans=x;getpre(c[x][1],k);} else getpre(c[x][0],k);
}
void getpos(int x,int k)
{
	if (!x) return;
	if (num[x]>k) {ans=x;getpos(c[x][0],k);} else getpos(c[x][1],k);
}
int main() 
{
	srand((unsigned)time(0));
	scanf("%d",&n);
	for (i=1;i<=n;i++)
	{
		scanf("%d%d",&opt,&x);
		switch(opt)
		{
			case 1:ins(root,x);break;
			case 2:del(root,x);break;
			case 3:printf("%d\n",rank(root,x));break;
			case 4:printf("%d\n",findkth(root,x));break;
			case 5:getpre(root,x);printf("%d\n",num[ans]);break;
			case 6:getpos(root,x);printf("%d\n",num[ans]);break;
		}
	}
}


展开阅读全文

没有更多推荐了,返回首页