普通平衡树 Splay

Splay 简介

Splay(伸展树),又叫做分裂树,是一种自调整形式的二叉查找树,满足二叉查找树的性质:一个节点左子树的所有节点的权值,均小于这个节点的权值。且其右子树所有节点的权值,均大于这个节点的权值。
因此Splay的中序遍历是一个递增序列。

Splay可以用来维护实链剖分(LCT)等,作为普通平衡树,它的优势在于不需要记录用于平衡树的冗余信息。

Splay维护一个有序集合,支持如下操作:

  1. 向集合中添加一个数
  2. 删除集合中的一个数
  3. 求出一个数的排名
  4. 根据排名求出这个数
  5. 查找一个数的前驱
  6. 查找一个数的后继

Splay原理以及实现

模板题

约定

为了代码简洁以及安全,我们用数组模拟Splay,并且做出规定如下性质:

  • 安全性:不在Splay上的节点,以及被删除的节点,其所有信息应该被清空。
  • 保证:我们保证函数不可能被非法调用,或者所有可能的非法调用是无害的,因此不需要在被调用的函数内部进行特判。
    例如:push_up(0)是无害的。
  • 代码重用:我们尽可能的保证代码重用
  • 节点从1开始编号,0号节点可能有多余的子孙/后代信息,但是其val,cnt,siz信息始终为0

或许每一个约定都并不是完全必要的。

节点:node

Splay上的一个节点(node)维护这样几个信息:

  • fa:这个节点的父亲编号,fa=0表示没有父亲
  • ch[0]:节点的左儿子编号,ch[0]的别名是l,若l=0表示没有左儿子
  • ch[1]:节点的右儿子编号,ch[1]的别名是r,若r=0表示没有右儿子
  • val:节点的权值
  • cnt:节点权值在集合中出现的次数
  • siz:以此节点为根的子树的大小
  • 成员函数set(v,c,s):用来初始化节点信息,使得val=v,cnt=c,siz=s,并且让fa=l=r=0。其中cs的默认值为1
const int N=;
struct node {
	int fa,ch[2],val,cnt,siz;
	int&l=ch[0],&r=ch[1];
	void set(int v,int c=1,int s=1) {
		fa=l=r=0;
		val=v;
		cnt=c;
		siz=s;
	}
} t[N+5];
int tot,root;

左右儿子函数(get)

函数原型:

bool get(int);

函数get(u)返回编号为u的节点是其父亲的左儿子(返回0)或者右儿子(返回1)。

函数定义:

bool get(int u) {
	return t[t[u].fa].r==u;
}

上传(push_up)

函数原型:

void push_up(int);

函数push_up(u)将编号为u节点用自己的两个儿子的信息更新自己的siz信息。当有儿子编号为0时不影响,因为我们保证0号节点的siz信息为0

函数定义:

void push_up(int u) {
	t[u].siz=t[t[u].l].siz+t[t[u].r].siz+t[u].cnt;
}

事实上push_up(0)也不影响0节点的siz,因为调用push_up(0)仅在pop函数中root=0时,但此时由于早已del0节点的左右儿子,因此0节点必然没有左右儿子的信息。

加入节点(add)

函数原型:

void add(int,int,bool);

函数add(fa,son,k)将编号为son的节点加入Splay,并且它是父亲fak侧儿子。

函数定义:

void add(int fa,int son,bool k) {
	t[t[son].fa=fa].ch[k]=son;
}

删除节点(del)

函数原型:

void del(int);

函数del(u)将编号为u的节点从Splay中删除,这需要操作它的父亲和左右儿子,并且将它的三个权值(val,cnt,siz)清空。

函数定义:

void del(int u) {
	t[t[u].l].fa=t[t[u].r].fa=t[t[u].fa].ch[get(u)]=0;
	t[u].set(0,0,0);
}

旋转(rotate)

Splay的单次操作复杂度并不是严格 O ( log ⁡ n ) O(\log n) O(logn)的,但是Splay依靠其伸展操作(splay)使得总复杂度为均摊 O ( n log ⁡ n ) O(n\log n) O(nlogn)(而不是期望 O ( n log ⁡ n ) O(n\log n) O(nlogn))的。

在伸展树上的一般操作都基于伸展操作:假设想要对一个二叉查找树执行一系列的查找操作,为了使整个查找时间更小,被查频率高的那些条目就应当经常处于靠近树根的位置。于是想到设计一个简单方法, 在每次查找之后对树进行重构,把被查找的条目搬移到离树根近一些的地方。伸展树应运而生。伸展树是一种自调整形式的二叉查找树,它会沿着从某个节点到树根之间的路径,通过一系列的旋转把这个节点搬移到树根去。

函数原型

void rotate(int);

当树是完全二叉树时,单次查询复杂度为 O ( log ⁡ n ) O(\log n) O(logn)
当树是一条链时,单次查询复杂度为 O ( n ) O(n) O(n)
rotate通过改变树的形态,达到使得Splay的均摊复杂度为 O ( log ⁡ n ) O(\log n) O(logn)的目的。

函数rotate(u)将编号为u的节点旋转一次。

旋转原理

首先我们需要记录一个变量k

  • k=get(u)

这表明了编号为u的节点是其父亲的哪侧儿子,k=0表示左儿子,k=1表示右儿子。

旋转过程需要保存几个节点编号:

  • u:当前节点
  • fa:当且节点的父亲
  • son:节点t[u]的异侧儿子,即son=t[u].ch[k^1]。例如:如果t[u]t[fa]的左儿子,那么t[son]就是t[u]的右儿子。
  • ffa:当前节点的父亲的父亲。

画出一个图来示意一下:
在这里插入图片描述
在这里,t[fa]t[ffa]的哪侧儿子无关紧要。

接下来我们修改树的形态,完成三步操作:

  1. u顶替掉原来fa的位置: 把u设置为ffa的儿子,fa是哪侧儿子,u就是哪侧儿子。
  2. fa顶替掉原来son的位置:fa变成uk^1儿子
  3. son设为fa的同侧儿子,替代uson变成fak儿子

还是看代码比较好懂:

int k=get(u),son=t[u].ch[k^1],fa=t[u].fa,ffa=t[fa].fa;
add(ffa,u,get(fa));
add(u,fa,k^1);
add(fa,son,k);

画个图:
在这里插入图片描述

直接背下来写得比较快。

旋转的性质

二叉查找树的性质:中序遍历是一个递增序列。

旋转的性质:旋转不会改变树的中序遍历。(显然)

旋转实现

完整代码是这样的:

void rotate(int u) {
	int k=get(u),son=t[u].ch[k^1],fa=t[u].fa,ffa=t[fa].fa;
	add(ffa,u,get(fa));
	add(u,fa,k^1);
	add(fa,son,k);
	push_up(fa);
	push_up(u);
}

注意最后要更新节点信息。先push_up父亲,再push_up自身,因为此时,原来的父亲是自身的儿子。

保证编号为u的节点存在父亲。
(事实上,可能会有son=0ffa=0,使得编号为0的节点可能携带有额外的祖先/后代信息,但是这不影响。)

其实我们还可以选择把子孙转成指定祖先的儿子处就停止,这里不多说了。

伸展(splay)

函数原型:

int splay(int);

伸展操作是执行若干次旋转操作,把编号为u的节点旋转到根,并返回u的编号。

执行的方法是这样的:

记录当且节点的编号u,更新它目前的父亲编号fa=t[u].fa,注意u的父亲是不断变化的,因此要更新:

  1. 如果u没有父亲,说明u是根节点:停止
  2. 如果fa不存在父亲,说明u再旋转一次就会旋转到根:rotete(u)
  3. get(fa)==get(u),说明ufa是同侧儿子,先旋转fa,再旋转urotate(fa),rotate(u)
  4. get(fa)!=get(u),说明ufa是异侧儿子,旋转两次urotate(u),rotate(u)

写成代码是这样的:

int splay(int u) {
	for(int fa; (fa=t[u].fa); rotate(u))
		if(t[fa].fa)
			rotate(get(u)==get(fa)?fa:u);
	return root=u;
}

注意最后把根节点编号设为u

伸展主要有三个作用:

  1. 可以保证时间复杂度
  2. rotate内有push_up函数,如果修改了u的信息,伸展一下可以更新到根节点的链上信息
  3. u旋转到根便于下一步操作

加入值(push)

函数原型:

int push(int);

函数push(val)val在集合中出现的次数增加1,并返回val所在的节点编号,如果val在集合中原来并不存在,就创建一个新节点。

函数分为三种情况讨论:

  1. Splay为空:直接新建一个节点,然后把根设为这个节点。
  2. Splay中以前存在val这个值:找到存储这个值的节点,先把它旋转到根,然后把它的cnt增加1,push_up以更新信息
    (因为此时这个节点已经是根了,对它调用splay不会rotate,因此必须手动psuh_up
    即使我们先前不把这个节点旋转到根,但是这个节点可能原本就是根,还是需要更新一下siz信息)
  3. Splay中不存在val这个值:找到一个合适的叶子节点,然后对val新建一个节点,并且把新节点的父亲设为这个叶子节点。把这个节点旋转到根。

为了保证时间复杂度,同时为了更新链上记录的siz信息,最后都要把val所在的节点旋转到根。

函数定义:

int push(int val) {
	if(!root) {
		t[++tot].set(val);
		return root=tot;
	}
	int x=val_find(val);这里的val_find函数很特殊,如果找到val,会返回这个节点作为根节点,否则会返回一个可以作为新节点父亲的叶子节点
	if(t[x].val==val) {
		t[x].cnt++;
		push_up(x);
		return x;
	}
	t[++tot].set(val);要先set再加边,否则set会将t[tot]上存储的祖先/子孙信息清除
	add(x,tot,t[x].val<val);
	return splay(tot);
}

删去值(pop)

函数原型:

void pop(int);

函数pop(val)将集合中val出现的次数减1,保证val之前至少出现过一次。

函数分几种情况讨论:
首先找到val所在的节点的编号,设为u,然后把这个节点旋转到根。

  1. 如果t[u].cnt>1:直接让cnt--
  2. 如果u至少没有一个儿子,那就把根设为它的另一个儿子,然后删除u
    (如果u没有任何一个儿子是不影响的。)
  3. 否则,说明u既有左儿子,又有右儿子,也就是说val既有前驱又有后继:
    因此找到val的前驱,把前驱旋转到根,此时u一定是根的右儿子,而且由于根是前驱,所以u没有左儿子,因此直接把u的右儿子设为根的右儿子,然后删除u即可。

注意最后要push_up(root),因为第1,3种情况下需要更新根节点信息。

函数实现:

void pop(int val) {
	int u=val_find(val);
	if(t[u].cnt>1) t[u].cnt--;
	else if(!t[u].l||!t[u].r) root=t[u].l|t[u].r,del(u);
	else {
		pre(val);
		int r=t[u].r;
		del(u);这里要先清除u,再连边。否则清除u时会顺便擦除根节点和r节点的祖先关系信息
		add(root,r,1);此时前驱是根节点,把u的右儿子设为其前驱的右儿子
	}
	push_up(root);
}

用值查找(val_find)

函数原型:

int val_find(int);

函数val_find(val)在集合中查找值val,如果它出现过,那就把val所在的节点旋转到根,并且返回它的编号,如果它没有出现过,那就返回一个可以作为val父亲的叶子节点编号。
(如果此时树为空,函数会返回0,尽管不会出现这样的调用)

主要做法就是从根节点开始找,如果找到了就返回,没找到就按照大小关系继续往下走。
如果找到叶子节点还没找到val就返回它的父亲。

函数定义:

int val_find(int val) {
	int u=root,fa=0;
	while(u)
		if(t[fa=u].val==val) return splay(u);
		else u=t[u].ch[t[u].val<val];
	return fa;
}

用排名查找(rank_find)

函数原型:

int rank_find(int,int);

函数rank_find(u,rank)查找u子树内排名rank的节点,并返回节点编号。注意这里是子树内排名,而不是全局排名。

我们通常调用时参数u=root,即查询全局排名。
rank_find函数设计为两个参数,一方面是为了方便递归调用,另一方面,不为其提供一个参数的重载版本是为了防止将其与val_find函数与find_rank函数混淆。

rank_find(u,rank)函数这样设计:
分情况讨论:

  1. 如果rank<=左子树大小,递归到左儿子:rank_find(t[u].l,rank)
  2. 否则,如果rank>左子树大小+自身节点的cnt,递归到右儿子:rank_find(t[u].r,rank-t[t[u].l].siz-t[u].cnt)
  3. 否则:旋转并且返回自身节点编号

这种独特的递归顺序使得如果查询的rank大于子树之内的最大排名,会返回子树最大值的节点编号,避免了进一部的分情况讨论。

函数定义:

int rank_find(int u,int rank) {
	int l=t[t[u].l].siz;这样可以少打很多字
	if(rank<=l) return rank_find(t[u].l,rank);
	else if(rank>l+t[u].cnt) return rank_find(t[u].r,rank-l-t[u].cnt);
	return splay(u);
}

查询值的排名(find_rank)

函数原型:

int find_rank(int);

函数find_rank(val)查询值val的排名,不保证val出现过。
没有提供查询节点排名的函数是因为节点不存在排名,如果想要查询节点u对应的权值的排名,可以调用find_rank(t[u].val)

查询val的排名,可以通过把val加入集合一次,然后把它对应的节点旋转到根。那么val的排名就是它对应节点的左子树的大小+1
然后再把val在集合中删去一次。

函数定义:

int find_rank(int val) {
	int ans=t[t[push(val)].l].siz+1;
	pop(val);
	return ans;
}

查找前驱/后继(bound)

函数原型:

int bound(int,bool);

函数bound(val,k)用于查询前驱/后继,旋转节点到根,并返回对应的节点编号。
函数bound(val,0)用于查询值val的前驱。
函数bound(val,1)用于查询值val的后继。

bound原理

这里以查询前驱举例:
查询val前驱的方法就是,无论Splay中是否存在val,我们都先push(val),这样Splay内肯定存在val,且为Splay的根。
走到根的左儿子上,然后不断地走右儿子,直到走到叶子节点即为前驱,记录答案后pop(val)

查询后继的方法是类似的:先push(val),走到根的右儿子上,然后不断地走左儿子,叶子节点即为前驱,记录答案后pop(val)

注意到可以把这两种情况合并起来:设k=0表示查询前驱,k=1表示查询后继,则函数定义如下:

int bound(int val,bool k) {
	int u=t[push(val)].ch[k];
	while(t[u].ch[k^1]) u=t[u].ch[k^1];
	pop(val);
	return splay(u);
}

前驱(pre)

函数原型:

int pre(int);

pre为查询前驱提供了专门的接口。
函数pre(val)表示查询val的前驱,把前驱旋转到根,并且返回前驱编号。

val可以比集合中的任何数都要大,但是不能没有前驱,否则运行可能出现问题,我们没有保证splay(0)不会出错,因为我们没有保证t[0]不携带非零的祖先后代信息。

如果非要这样查询可能没有前驱/后继的数的话可以设置哨兵:push(-INF),push(INF)

函数定义:

int pre(int val) {
	return bound(val,0);
}

后继(nxt)

函数原型:

int nxt(int);

函数nxt(val)表示查询val的后继,把后继旋转到根,并返回后继编号。
必须要保证val有后继。

函数定义:

int nxt(int val) {
	return bound(val,1);
}

完整代码

空间复杂度

注意到Splay的任意一种操作至多创建一个节点,因此空间复杂度为一倍操作次数。(本题要算上一开始的 1 0 5 10^5 105次操作)

代码

#include<iostream>
using namespace std;
const int N=2e6;
struct node {
	int fa,ch[2];
	int val,cnt,siz;
	int &l=ch[0],&r=ch[1];
	void set(int v,int c=1,int s=1) {
		l=r=fa;
		val=v;
		cnt=c;
		siz=s;
	}
}t[1100005];
int tot,root;
bool get(int);
void push_up(int);
void add(int,int,bool);
void del(int);
void rotate(int);
int splay(int);
int push(int);
void pop(int);
int val_find(int);
int rank_find(int,int);
int find_rank(int);
int bound(int,bool);
int pre(int);
int nxt(int);
int a[N+5];
int main() {
	int n,m;
	cin>>n>>m;
	for(int i=1;i<=n;i++) cin>>a[i];
	for(int i=1;i<=n;i++) push(a[i]);
	int ans=0,last=0;
	while(m--) {
		int op,x;
		cin>>op>>x;
//		if(op==1) push(x);
//		if(op==2) pop(x);
//		if(op==3) cout<<find_rank(x)<<endl;
//		if(op==4) cout<<t[rank_find(root,x)].val<<endl;
//		if(op==5) cout<<t[pre(x)].val<<endl;
//		if(op==6) cout<<t[nxt(x)].val<<endl;
		x^=last;
		if(op==1) push(x);
		if(op==2) pop(x);
		if(op==3) ans^=(last=find_rank(x));
		if(op==4) ans^=(last=t[rank_find(root,x)].val);
		if(op==5) ans^=(last=t[pre(x)].val);
		if(op==6) ans^=(last=t[nxt(x)].val);
	}
	cout<<ans;
}
bool get(int u) {
	return t[t[u].fa].r==u;
}
void push_up(int u) {
	t[u].siz=t[t[u].l].siz+t[t[u].r].siz+t[u].cnt;
}
void add(int fa,int son,bool k) {
	t[t[son].fa=fa].ch[k]=son;
}
void del(int u) {
	t[t[u].l].fa=t[t[u].r].fa=t[t[u].fa].ch[get(u)]=0;
	t[u].set(0,0,0);
}
void rotate(int u) {
	int k=get(u),son=t[u].ch[k^1],fa=t[u].fa,ffa=t[fa].fa;
	add(ffa,u,get(fa));
	add(u,fa,k^1);
	add(fa,son,k);
	push_up(fa);
	push_up(u);
}
int splay(int u) {
	for(int fa;(fa=t[u].fa);rotate(u)) 
		if(t[fa].fa)
			rotate(get(fa)==get(u)?fa:u);
	return root=u;
}
int push(int val) {
	if(!root) {
		t[++tot].set(val);
		return root=tot;
	}
	int x=val_find(val) ;
	if(t[x].val==val) {
		t[x].cnt++;
		push_up(x);
		return x;
	}
	t[++tot].set(val);
	add(x,tot,t[x].val<val);
	return splay(tot);
}
void pop(int val) {
	int u=val_find(val);
	if(t[u].cnt>1) t[u].cnt--;
	else if(!t[u].l||!t[u].r) root=t[u].l|t[u].r,del(u);
	else {
		pre(val);
		int r=t[u].r;
		del(u);
		add(root,r,1);
	}
	push_up(root);
}
int val_find(int val) {
	int u=root,fa=0;
	while(u) 
		if(t[fa=u].val==val) return splay(u);
		else u=t[u].ch[t[u].val<val];
	return fa;
}
int rank_find(int u,int rank) {
	int l=t[t[u].l].siz;
	if(rank<=l) return rank_find(t[u].l,rank);
	else if(rank>t[u].cnt+l) return rank_find(t[u].r,rank-t[u].cnt-l);
	return splay(u);
}
int find_rank(int val) {
	int ans=t[t[push(val)].l].siz+1;
	pop(val);
	return ans;
}
int bound(int val,bool k) {
	int u=t[push(val)].ch[k];
	while(t[u].ch[k^1]) u=t[u].ch[k^1];
	pop(val);
	return splay(u);
}
int pre(int val) {
	return bound(val,0);
}
int nxt(int val) {
	return bound(val,1);
}

完整版splay

完整代码:

#include<iostream>
using namespace std;
const int N=1.1e6;
struct node{
	int ch[2],fa;
	int val,cnt,siz;
	int&l=ch[0],&r=ch[1];
	void set(int v,int c=1,int s=1){
		l=r=fa=0;
		val=v;
		cnt=c;
		siz=s;
	}
}t[N+5];
int tot,root;
bool get(int);
void push_up(int);
void add(int,int,bool);
void del(int);
void rotate(int);
int splay(int,int=0);
int push(int);
void pop(int);
int val_find(int);
int rank_find(int,int);
int find_rank(int);
int bound(int,bool);
int pre(int);
int nxt(int);
int main(){
	int n,m;
	cin>>n>>m;
	push(2147483647);
	push(-2147483648);
	for(int i=1,x;i<=n;i++) cin>>x,push(x);
	int ans=0,lst=0;
	while(m--){
		int op,x;
		cin>>op>>x;
		x^=lst;
		if(op==1) push(x);
		if(op==2) pop(x);
		if(op==3) ans^=lst=find_rank(x)-1;
		if(op==4) ans^=lst=t[rank_find(root,x+1)].val;
		if(op==5) ans^=lst=t[splay(pre(x))].val;
		if(op==6) ans^=lst=t[splay(nxt(x))].val;
		
// 		if(op==1) push(x);
// 		if(op==2) pop(x);
// 		if(op==3) cout<<find_rank(x)-1<<endl;
// 		if(op==4) cout<<t[rank_find(root,x+1)].val<<endl;
// 		if(op==5) cout<<t[splay(pre(x))].val<<endl;
// 		if(op==6) cout<<t[splay(nxt(x))].val<<endl;
	}
	cout<<ans;
}
bool get(int u){
	return t[t[u].fa].r==u;
}
void push_up(int u){
	t[u].siz=t[t[u].l].siz+t[t[u].r].siz+t[u].cnt;
}
void add(int fa,int u,bool k){
	t[t[u].fa=fa].ch[k]=u;
}
void del(int u){
	t[t[u].l].fa=t[t[u].r].fa=t[t[u].fa].ch[get(u)]=0;
	t[u].set(0,0,0);
}
void rotate(int u){
	int k=get(u),fa=t[u].fa,ffa=t[fa].fa,son=t[u].ch[k^1];
	add(ffa,u,get(fa));
	add(u,fa,k^1);
	add(fa,son,k);
	push_up(fa);
	push_up(u);
}
int splay(int u,int v){
	for(int fa;(fa=t[u].fa)^v;rotate(u))
		if(t[fa].fa^v)
			rotate(get(fa)==get(u)?fa:u);
	!v&&(root=u);
	return u;
}
int push(int val){
	if(!root){
		t[++tot].set(val);
		return root=tot;
	}
	int x=val_find(val);
	if(t[x].val==val){
		t[x].cnt++;
		push_up(x);
		return x;
	}
	t[++tot].set(val);
	add(x,tot,t[x].val<val);
	return splay(tot);
}
void pop(int val){
	int u=val_find(val);
	if(t[u].cnt>1) t[u].cnt--;
	else {
		int Pre=pre(val),Nxt=nxt(val);
		splay(Pre);
		splay(Nxt,Pre);
		del(u);
		splay(Nxt);
	}
	push_up(root);
}
int val_find(int val){
	int u=root,fa;
	while(u)
		if(t[fa=u].val==val)
			return splay(u);
		else 
			u=t[u].ch[t[u].val<val];
	return fa;
}
int rank_find(int u,int rank){
	int l=t[t[u].l].siz;
	if(rank<=l) return rank_find(t[u].l,rank);
	if(l+t[u].cnt<rank) return rank_find(t[u].r,rank-t[u].cnt-l);
	return splay(u);
}
int find_rank(int val){
	int ans=t[t[push(val)].l].siz+1;
	pop(val);
	return ans;
}
int bound(int val,bool k){
	int u=t[push(val)].ch[k];
	while(t[u].ch[k^1])
		u=t[u].ch[k^1];
	pop(val);
	return u;
}
int pre(int val){
	return bound(val,0);
}
int nxt(int val){
	return bound(val,1);
}

伸展(splay)

函数原型:

int splay(int,int=0);

函数splay(u,v)的功能是把节点u旋转为v的儿子,如果v0,就旋转到根,然后返回u节点的编号。这是完整的伸展功能。

原理是一样的,代码如下:

int splay(int u,int v) {
	for(int fa;(fa=t[u].fa)^v;rotate(u))
		if(t[fa].fa^v)
			rotate(get(fa)==get(u)?fa:u);
	if(!v)
		root=u;
	return u;
}

删除

如果我们选择添加两个哨兵 − ∞ , + ∞ -\infty,+\infty ,+,那我们可以认为任何一个真正存在于集合之内的元素都有前驱/后继。这个时候就有了Splay的一个经典操作:夹挤。

如果我们要删除val,我们可以把val的前驱旋转到根,然后再把val的后继旋转为前驱的儿子,此时代表val的节点就是其后继的左儿子,并且显然val没有任何一个儿子,此时我们可以直接删除val,这样就再次减少了特判:

void pop(int val) {
	int u=val_find(val);
	if(t[u].cnt>1) t[u].cnt--;
	else {
		int x=pre(val),y=nxt(val);
		splay(x);
		splay(y,x);
		del(u);
		splay(y);
	}
	push_up(root);
}

不过如果我们要这样写的话,就不能在pre,nxt函数内部把对应的节点旋转到根了。但是我们仍然要注意,每次查询前驱/后继都必须要splay,不然复杂度不对。因此我们在每一次查询前驱/后继的时候必须手动splay,对应的则是prenxtmain函数的修改:

int bound(int val,bool k) {
	int u=t[push(val)].ch[k];
	while(t[u].ch[k^1]) u=t[u].ch[k^1];
	pop(val);
	return u;
}
int pre(int val) {
	return bound(val,0);
}
int nxt(int val) {
	return bound(val,1);
}
int main() {
	int n,m;
	cin>>n>>m;
	push(2147483647);注意P6136的值域为2^30,因此不能选择1e9作为inf
	push(-2147483647);
	for(int i=1,x;i<=n;i++) cin>>x,push(x);
	int ans=0,last=0;
	while(m--) {
		int op,x;
		cin>>op>>x;
		x^=last;
		if(op==1) push(x);
		if(op==2) pop(x);
		if(op==3) ans^=(last=find_rank(x)-1);
		if(op==4) ans^=(last=t[rank_find(root,x+1)].val);
		if(op==5) ans^=(last=t[splay(pre(x))].val);
		if(op==6) ans^=(last=t[splay(nxt(x))].val);
//		if(op==1) push(x);
//		if(op==2) pop(x);
//		if(op==3) cout<<find_rank(x)-1<<endl;
//		if(op==4) cout<<t[rank_find(root,x+1)].val<<endl;
//		if(op==5) cout<<t[splay(pre(x))].val<<endl;
//		if(op==6) cout<<t[splay(nxt(x))].val<<endl;
	}
	cout<<ans;
}

前驱/后继

查询前驱/后继,可以采用刚才的办法,也可以采用另一种办法,即:
查询前驱即查询比val排名少 1 1 1的值,使用rank_find即可完成。查询后继同理。

看似可以不写bound,能够减小码量。不过事实上实现起来并不比bound简单。

后话

关于pop和pre

有一种观点认为,对pre函数查询不在集合里面的val会导致创建新节点,而删除val时又有可能导致查询val的前驱,这可能会导致循环调用。

但是这种说法是错误的,因为事实上,如果在pop(val)时调用pre(val),进而导致了一次push(val)后再pop(val),此时val对应节点的cnt至少为2了,所以在本层pop(val)不会调用pre(val),而是会将cnt--

指针版Splay

完整代码:

#include<iostream>
#include<cstdio>
using namespace std;
const int N=1.1e6;
class node {
	public:
		int val,cnt,siz;
		node*fa,*ch[2];
		node*&l=ch[0],*&r=ch[1];
		node();
		void set(int,int=1,int=1);
} t[N+5];
node*nt=&t[0];
node::node(){
	val=cnt=siz=0;
	fa=l=r=nt;
}
void node::set(int v,int c,int s) {
	l=r=fa=nt;
	val=v,cnt=c,siz=s;
}
int tot;
node*root=nt;
bool get(node*);
void push_up(node*);
void add(node*,node*,bool);
void del(node*);
void rotate(node*);
node*splay(node*,node* =nt);
node*push(int);
void pop(int);
node*val_find(int);
node*rank_find(node*,int);
int find_rank(int);
node*bound(int,bool);
node*pre(int);
node*nxt(int);

//void check(node*u=root){
//	printf("%lld(%lld,%lld):%d %d %d\n",u-nt,u->l-nt,u->r-nt,u->val,u->cnt,u->siz);
//	if(u->l!=nt)check(u->l); 
//	if(u->r!=nt)check(u->r); 
//}
int main() {
	push(2147483647);
	push(-2147483648);
//	check();
	int n,m;
	cin>>n>>m;
	for(int i=1,x; i<=n; i++)cin>>x,push(x);
	int ans=0,last=0;
	while(m--) {
		int op,x;
		cin>>op>>x;
		x^=last;
//		if(op==1) push(x);
//		if(op==2) pop(x);
//		if(op==3) cout<<find_rank(x)-1<<endl;
//		if(op==4) cout<<rank_find(root,x+1)->val<<endl;
//		if(op==5) cout<<splay(pre(x))->val<<endl;
//		if(op==6) cout<<splay(nxt(x))->val<<endl;
		if(op==1) push(x);
		if(op==2) pop(x);
		if(op==3) ans^=(last=find_rank(x)-1);
		if(op==4) ans^=(last=rank_find(root,x+1)->val);
		if(op==5) ans^=(last=splay(pre(x))->val);
		if(op==6) ans^=(last=splay(nxt(x))->val);
	}
	cout<<ans;
}
bool get(node*u) {
	return u->fa->r==u;
}
void push_up(node*u) {
	u->siz=u->l->siz+u->r->siz+u->cnt;
}
void add(node*fa,node*u,bool k) {
	(u->fa=fa)->ch[k]=u;
}
void del(node*u) {
	u->l->fa=u->r->fa=u->fa->ch[get(u)]=nt;
	u->set(0,0,0);
}
void rotate(node*u) {
	node*fa=u->fa,*ffa=fa->fa;
	bool k=get(u);
	node*son=u->ch[k^1];
	add(ffa,u,get(fa));
	add(u,fa,k^1);
	add(fa,son,k);
	push_up(fa);
	push_up(u);
}
node*splay(node*u,node*v) {
	for(node*fa; (fa=u->fa)!=v; rotate(u))
		if(fa->fa!=v)
			rotate(get(fa)==get(u)?fa:u);
	if(v==nt)
		root=u;
	return u;
}
node*push(int val) {
	if(root==nt) {
		root=&t[++tot];
		root->set(val);
		return root;
	}
	node*x=val_find(val);
	if(x->val==val) {
		x->cnt++;
		push_up(x);
		return x;
	}
	node*y=&t[++tot];
	y->set(val);
	add(x,y,x->val<val);
	return splay(y);
}
void pop(int val) {
	node* u=val_find(val);
	if(u->cnt>1) u->cnt--;
	else {
		node* pr=pre(val),*nx=nxt(val);
		splay(pr);
		splay(nx,pr);
		del(u);
		push_up(pr);
	}
	push_up(root);
}
node*val_find(int val) {
	node*u=root,*fa;
	while(u!=nt)
		if((fa=u)->val==val)
			return splay(u);
		else
			u=u->ch[u->val<val];
	return fa;
}
node*rank_find(node*u,int rank) {
	if(rank<=u->l->siz) return rank_find(u->l,rank);
	else if(rank>u->l->siz+u->cnt) return rank_find(u->r,rank-u->l->siz-u->cnt);
	return splay(u);
}
int find_rank(int val) {
	int ans=push(val)->l->siz+1;
	pop(val);
	return ans;
}
node*bound(int val,bool k) {
	node*u=push(val)->ch[k];
	while(u->ch[k^1]!=nt)
		u=u->ch[k^1];
	pop(val);
	return u;
}
node*pre(int val) {
	return bound(val,0);
}
node*nxt(int val) {
	return bound(val,1);
}

指针写起来还是方便一些,主要的不同之处就是判断是否走到了空节点不再是if(!u),而是if(u==t[0]),并且必须要提供显式的构造函数,确保所有指向空节点的指针指向的是t[0]

文艺平衡树(未完成)

模板题

权值Splay节点对应一个真实值。
区间Splay(文艺平衡树)节点对应一个下标。

一个节点在Splay中对应的下标并不是固定的,而是由其在Splay中的位置决定。如果它在平衡树中序遍历中的位置为 k k k,换句话说这个节点的排名为 k k k,那么它对应的下标就是 a k a_k ak,表示 a k = v a l a_k=val ak=val

因此文艺平衡树中val值不满足二叉搜索树的性质。因为区间Splay本来就不是平衡二叉搜索树,而是平衡二叉区间树。

Splay的旋转操作不会改变平衡树中序遍历,因此仍然可以使用旋转操作维持Splay的平衡。

Splay区间插入/删除什么的都根据排名做操作。其他操作维护懒标记即可。比如区间翻转懒标记。

更详细的东西有时间再写。

#include<iostream>
using namespace std;
const int N=1e5;
struct node {
	int fa,ch[2],val,siz,tag,&l=ch[0],&r=ch[1];
	void set(int v,int s=1) {
		fa=l=r=0;
		val=v;
		siz=s;
		tag=0;
	}
} t[N+5];
int root,tot;
bool get(int);
void push_up(int);
void push_down(int);
void add(int,int,bool);
void rotate(int);
int splay(int,int=0);
int push(int);
void pop(int);
int val_find(int);
int rank_find(int,int);
void check(int u) {
	cout<<u<<"("<<t[u].l<<','<<t[u].r<<") "<<t[u].val<<' '<<t[u].siz<<' '<<t[u].tag<<endl;
	if(t[u].l) check(t[u].l);
	if(t[u].r) check(t[u].r);
}
int main() {
	int n,m;
	cin>>n>>m;
	push(0);
	push(n+1);
	for(int i=1; i<=n; i++) push(i);
	while(m--) {
		int l,r;
		cin>>l>>r;
		int x=rank_find(root,l),y=rank_find(root,r+2);
		splay(x);
		splay(y,x);
		t[t[y].l].tag^=1;
	}
//	cout<<"***"<<endl;
//	check(root);
	for(int i=1; i<=n; i++) {
//		cout<<"***"<<endl;
		cout<<t[rank_find(root,i+1)].val<<' ';
//		check(root);
	}
}
bool get(int u) {
	return t[t[u].fa].r==u;
}
void push_up(int u) {
	t[u].siz=t[t[u].l].siz+t[t[u].r].siz+1;
}
void push_down(int u) {
	if(t[u].tag) {
		int&l=t[u].l,&r=t[u].r;
		swap(l,r);

		t[l].tag^=1;
		t[r].tag^=1;
		t[u].tag=0;
	}
}
void add(int fa,int u,bool k) {
	t[t[u].fa=fa].ch[k]=u;
}
void rotate(int u) {
	int k=get(u),fa=t[u].fa,ffa=t[fa].fa,son=t[u].ch[k^1];
	add(ffa,u,get(fa));
	add(u,fa,k^1);
	add(fa,son,k);
	push_up(fa);
	push_up(u);
}
int splay(int u,int v) {
	for(int fa; (fa=t[u].fa)^v; rotate(u))
		if(t[fa].fa^v)
			rotate(get(fa)==get(u)?fa:u);
	if(!v) root=u;
	return u;
}
int push(int val) {
	if(!root) {
		t[root=++tot].set(val);
		return root;
	}
	int x=val_find(val);
	t[++tot].set(val);
	add(x,tot,t[x].val<val);
	return splay(tot);
}
int val_find(int val) {
	int u=root;
	while(t[u].ch[t[u].val<val])
		u=t[u].ch[t[u].val<val];
	return u;
}
int rank_find(int u,int rank) {
	push_down(u);
	int l=t[t[u].l].siz;
	if(rank<=l) return rank_find(t[u].l,rank);
	else if(l+1<rank) return rank_find(t[u].r,rank-l-1);
	return splay(u);
}

后记

于是皆大欢喜。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值