splay平衡树
简介
二叉查找树满足任意一个节点,它的左儿子的权值
<
\lt
<自己的权值
<
\lt
<右儿子的权值。平衡树在满足二叉查找树条件的情况下各点尽可能分布均匀,使时间常数最小化。
splay平衡树是一种按值排序时可以实现所有普通平衡树的操作,按位置排序时可以实现区间翻转和平移的多功能平衡树。
变量声明
f[i]//i的父节点
ch[i][0]//i的左儿子
ch[i][1]//i的右儿子
key[i]//i的关键字(一般为i代表的那个数字)
cnt[i]//i节点关键字出现的次数
siz[i]//i这个子树的大小(下面有多少个结点)
sz//整棵树的大小(可以参考链向前式星那个表示总边数的变量)
root//整棵树的根
基础操作
下面是几个简单的基本操作:
clear操作
作用:将当前点的各项值都清零(删除之后清理数据)。
inline void clear( int x ) {
ch[x][0] = ch[x][1] = f[x] = cnt[x] = key[x] = siz[x] = 0;
}
get操作
作用:判断当前的点是它父亲结点的左儿子还是右儿子。
inline int get( int x ) {
return ch[f[x]][1] == x;//0左1右
}
update操作
作用:更新当前结点的siz
值(发生修改后用)。
inline void update( int x ) {
if( x ) {
siz[x] = cnt[x];
if( ch[x][0] ) {
siz[x] += siz[ch[x][0]];
}
if( ch[x][1] ) {
siz[x] += siz[ch[x][1]];
}
}
}
rotate操作
作用:一个利用splay左右儿子标示的特殊性而把左旋右旋结合在一起的操作,使得平衡树更加平衡。
分析:
这是原来的树,我们现在要从点D开始rotate,让调换位置D和它的父结点B,我们在此列出了B,D以及所有可能被影响到的其它点。
这是我们希望的交换后的效果图。
我们先设变量which
等于get( x )
,其中x
代表点D,用which
表示D是左儿子还是右儿子。
我们通过观察,不难发现图中B的which
儿子,D的which ^ 1
儿子,以及B的父结点A(如果有的话)对应的左或右(当场判断)儿子也要换。
总而言之,我们先连BG,再连DB,最后连AD即可。并注意按顺序更新B,D的siz
值。
inline void rotate( int x ) {
int old = f[x] , oldf = f[old] , which = get( x );
ch[old][which] = ch[x][which ^ 1];
f[ch[old][which]] = old;
ch[x][which ^ 1] = old;
f[old] = x;
f[x] = oldf;
if( oldf ) {
ch[oldf][ch[oldf][1] == old] = x;
}
update( old );
update( x );
return;
}
splay操作
作用:rotate的发展,本质是不停地rotate,一直splay到根。
splay的过程中我们要分类讨论:
情况一:三点一线(x
,x
的父结点,x
的祖父结点在一条),先rotatex
的父结点,再rotatex
本身。否则会形成单旋使平衡树失衡;
情况二:没有三点一线,rotatex
即可。
inline void splay( int x ) {
for ( int fa ; fa = f[x] ; rotate(x) ) {
if( f[fa] ) {
rotate( ( get( x ) == get( fa ) ? fa : x ) );
}
}
root = x;
}
insert操作
作用:插入数据。
这里也要分类讨论:
如果root=0
,即树为空,我们处理几个变量后直接返回;
否则,我们按照二叉查找树的性质一直往下找,其中:
若当前结点的关键字和要插入的点一样的话,把这个点加一个权值,更新一下当前点和父结点的siz
和cnt
,再splay上去;
若到了最底下,直接插入,整棵树大小sz
加一,新结点的各项值更新一下(父,左右儿子,权值,大小),更新一下当前点的父结点的siz
,再splay上去。
inline void insert( int v ) {
if( root == 0 ) {
++sz;
root = sz;
ch[root][0] = ch[root][1] = f[root] = 0;
key[root] = v;
cnt[root] = siz[root] = 1;
return;
}
int cur = root , fa = 0;
while ( true ) {
if( key[cur] == v ) {
++cnt[cur];
update( cur );
update( fa );
splay( cur );//一边rotate,一边往上传值
break;
}
fa = cur;
cur = ch[cur][key[cur] < v];
if( cur == 0 ) {
++sz;
ch[sz][0] = ch[sz][1] = 0;
key[sz] = v;
siz[sz] = 1;
cnt[sz] = 1;
f[sz] = fa;
ch[fa][key[fa] < v] = sz;
update( fa );
splay( sz );
break;
}
}
}
find操作
作用:查询关键字为v
时的排名。
一开始ans
为零,当前点为root
。
如果v
比当前结点的关键字小,则应该向左子树寻找,ans
不变;
如果v
比当前结点的关键字大,则应该向右子树寻找,ans
加上左子树的siz
和当前点的cnt
。
找到之后ans
加一。
最后要splay,别的操作有用。
inline int find( int v ) {
int ans = 0 , cur = root;
while ( true ) {
if( v < key[cur] ) {
cur = ch[cur][0];
} else {
ans += ( ch[cur][0] ? siz[ch[cur][0]] : 0 );
if( v == key[cur] ) {
splay( cur );
return ans + 1;
}
ans += cnt[cur];
cur = ch[cur][1];
}
}
}
findth操作
作用:查询排名为k
的点。
一开始当前点为root
。
如果当前点有左子树,并且k
小于左子树大小时,可以向左子树寻找;
否则,先用tem
表示左子树的siz
(没有则为零)和当前点的cnt
,看看排名为k
的点是否为当前点,然后k
去减tem
,从右子树开始找。
inline int findth( int k ) {
int cur = root;
while ( true ) {
if( ch[cur][0] && k <= siz[ch[cur][0]] ) {
cur = ch[cur][0];
} else {
int tem = ( ch[cur][0] ? siz[ch[cur][0]] : 0 ) + cnt[cur];
if( cur <= tem ) {
return key[cur];
}
k -= tem;
cur = ch[cur][1];
}
}
}
pre/nxt操作
作用:pre找前驱,nxt找后继。
splay平衡树找前驱后继的思路是先插入被查找数(已被splay到根结点),前驱就是根节点左子树最右边的结点(最大的小于),后继就是根节点右子树最左边的结点(最小的大于),查找完后再删除被查找数。
inline int pre() {
int cur = ch[root][0];
while ( ch[cur][1] ) {
cur = ch[cur][1];
}
return cur;
}
inline int nxt() {
int cur = ch[root][1];
while ( ch[cur][0] ) {
cur = ch[cur][0];
}
return cur;
}
int main() {
insert( x );
pre();
del( x );
}
int main() {
insert( x );
nxt();
del( x );
}
del操作
作用:删除关键字为v
的点。
这个操作比较麻烦,我们先find一下v
,把它旋到根。
接下来我们要分多种情况讨论:
case 1:cnt[root]>1
,不止有一个,直接减一;
case 2:root
只有一个但没有子结点,直接clear;
case 3:如果root
只有左儿子或只有右儿子,直接删了root
,唯一的儿子做root
;
case 4:root
有两个儿子,我们要拿root
的前驱作新根,将原先root
的右子树接到新root
的右子树上(由于选的是前驱,原先root
一定没有左子树)。
删完后不忘update。
inline void del( int v ) {
find( v );
if( cnt[root] > 1 ) {
--cnt[root];
update( root );
return;
}
if( !ch[root][0] && !ch[root][1] ) {
clear( root );
root = 0;
sz = 0;
return;
}
if( !ch[root][0] ) {
int oldroot = root;
root = ch[root][1];
f[root] = 0;
clear( oldroot );
--sz;
return;
} else if( !ch[root][1] ) {
int oldroot = root;
root = ch[root][0];
f[root] = 0;
clear( oldroot );
--sz;
return;
}
int lpre = pre() , oldroot = root;
splay( lpre );
f[ch[oldroot][1]] = root;
ch[root][1] = ch[oldroot][1];
clear( oldroot );
update( root );
return;
}
完整代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define ll long long
#define N 100005
using namespace std;
int t , opt;
int sz = 0 , root = 0;
int ch[N][2] , f[N] , cnt[N] , key[N] , siz[N];
inline void clear( int x ) {
ch[x][0] = ch[x][1] = f[x] = cnt[x] = key[x] = siz[x] = 0;
}
inline int get( int x ) {
return ch[f[x]][1] == x;//0左1右
}
inline void update( int x ) {
if( x ) {
siz[x] = cnt[x];
if( ch[x][0] ) {
siz[x] += siz[ch[x][0]];
}
if( ch[x][1] ) {
siz[x] += siz[ch[x][1]];
}
}
return;
}
inline void rotate( int x ) {
int old = f[x] , oldf = f[old] , which = get( x );
ch[old][which] = ch[x][which ^ 1];
f[ch[old][which]] = old;
ch[x][which ^ 1] = old;
f[old] = x;
f[x] = oldf;
if( oldf ) {
ch[oldf][ch[oldf][1] == old] = x;
}
update( old );
update( x );
return;
}
inline void splay( int x ) {
for ( int fa ; ( fa = f[x] ) ; rotate(x) ) {
if( f[fa] ) {
rotate( ( get( x ) == get( fa ) ? fa : x ) );
}
}
root = x;
return;
}
inline void insert( int v ) {
if( root == 0 ) {
++sz;
root = sz;
ch[root][0] = ch[root][1] = f[root] = 0;
key[root] = v;
cnt[root] = siz[root] = 1;
return;
}
int cur = root , fa = 0;
while ( true ) {
if( key[cur] == v ) {
++cnt[cur];
update( cur );
update( fa );
splay( cur );//一边rotate,一边往上传值
return;
}
fa = cur;
cur = ch[cur][key[cur] < v];
if( cur == 0 ) {
++sz;
ch[sz][0] = ch[sz][1] = 0;
key[sz] = v;
siz[sz] = 1;
cnt[sz] = 1;
f[sz] = fa;
ch[fa][key[fa] < v] = sz;
update( fa );
splay( sz );
return;
}
}
}
inline int find( int v ) {
int ans = 0 , cur = root;
while ( true ) {
if( v < key[cur] ) {
cur = ch[cur][0];
} else {
ans += ( ch[cur][0] ? siz[ch[cur][0]] : 0 );
if( v == key[cur] ) {
splay( cur );
return ans + 1;
}
ans += cnt[cur];
cur = ch[cur][1];
}
}
}
inline int findth( int k ) {
int cur = root;
while ( true ) {
if( ch[cur][0] && k <= siz[ch[cur][0]] ) {
cur = ch[cur][0];
} else {
int tem = ( ch[cur][0] ? siz[ch[cur][0]] : 0 ) + cnt[cur];
if( k <= tem ) {
return key[cur];
}
k -= tem;
cur = ch[cur][1];
}
}
}
inline int pre() {
int cur = ch[root][0];
while ( ch[cur][1] ) {
cur = ch[cur][1];
}
return cur;
}
inline int nxt() {
int cur = ch[root][1];
while ( ch[cur][0] ) {
cur = ch[cur][0];
}
return cur;
}
inline void del( int v ) {
find( v );
if( cnt[root] > 1 ) {
--cnt[root];
update( root );
return;
}
if( !ch[root][0] && !ch[root][1] ) {
clear( root );
root = 0;
return;
}
if( !ch[root][0] ) {
int oldroot = root;
root = ch[root][1];
f[root] = 0;
clear( oldroot );
return;
} else if( !ch[root][1] ) {
int oldroot = root;
root = ch[root][0];
f[root] = 0;
clear( oldroot );
return;
}
int lpre = pre() , oldroot = root;
splay( lpre );
f[ch[oldroot][1]] = root;
ch[root][1] = ch[oldroot][1];
clear( oldroot );
update( root );
return;
}
int main() {
scanf("%d",&t);
int x;
while ( t-- ) {
scanf("%d",&opt);
switch( opt ) {
case 1 : {
scanf("%d",&x);
insert( x );
break;
}
case 2 : {
scanf("%d",&x);
del( x );
break;
}
case 3 : {
scanf("%d",&x);
printf("%d\n",find( x ));
break;
}
case 4 : {
scanf("%d",&x);
printf("%d\n",findth( x ));
break;
}
case 5 : {
scanf("%d",&x);
insert( x );
printf("%d\n",key[pre()]);
del( x );
break;
}
case 6 : {
scanf("%d",&x);
insert( x );
printf("%d\n",key[nxt()]);
del( x );
break;
}
}
}
return 0;
}
进阶操作一:区间翻转
splay除了满足一个普通平衡树的所有特点,在我们把key
关键字表示的值从结点数值的大小转变成结点的排名(在数组中排第几位),还可以进行区间的翻转和平移。
新增变量
pos[i]//初始序列中i号点对应的key值
tag[i]//标记,表示是否翻转
建树操作
作用:当key
关键字表示结点的排名时,平衡树的构建方法也发生了改变。
为了方便处理边界,如果我们要插入一个有n
个数的数组,我们要先插入key=1
和key=n+2
这两个结点来判断,对于剩下的点,它在原数组排第i
位,我们就将其赋值为key=i+1
。
赋值:
int main() {
for ( int i = 1 ; i <= n ; ++i ) {
pos[i + 1] = i;
}
pos[1] = -inf;
pos[n + 2] = inf;
}
建树(尽量平衡):
int build( int p , int l , int r ) {
if( l > r ) {
return 0;
}
int mid = ( l + r ) >> 1;
int cur = ++sz;
key[cur] = pos[mid];
f[cur] = p;
tag[cur] = 0;
ch[cur][0] = build( cur , l , mid - 1 );
ch[cur][1] = build( cur , mid + 1 , r );
update( cur );
return cur;
}
下传标记
作用:tag
标识区间的翻转,当该点的tag
为一,我们交换其左右儿子表示区间翻转,并把tag
传给该点的儿子。
inline void pushdown( int x ) {
if( x && tag[x] ) {
tag[ch[x][0]] ^= 1;
tag[ch[x][1]] ^= 1;
swap( ch[x][0] , ch[x][1] );
tag[x] = 0;
}
}
有目标的splay操作
作用:与普通的splay区别不是很大,加了一个目标goal
方便操作,结果是成为goal
的儿子(goal
为零时成为根节点)。
inline void splay( int x , int goal ) {
for ( int fa ; ( fa = f[x] ) != goal ; rotate(x) ) {
if( f[fa] != goal ) {
rotate( ( get( x ) == get( fa ) ? fa : x ) );
}
}
if( !goal ) {
root = x;
}
return;
}
turn操作
作用:实现区间翻转。
当我们输入l
和r
时,由于边界结点的存在,我们需要翻转的其实是区间
[
l
+
1
,
r
+
1
]
[l+1,r+1]
[l+1,r+1]。
基本思路就是先找出排位l
的点,把它splay到根,然后再找出排位r+2
的点,把它splay到根(排位l
的那个点)的儿子位置(右子树上),我们不难发现此时根节点右子树的左子树上的所有点均在区间
[
l
+
1
,
r
+
1
]
[l+1,r+1]
[l+1,r+1]内,且没有遗漏的。我们给根节点右子树的左子树打标记即可。
inline void turn( int l , int r ) {
l = findth( l );
r = findth( r + 2 );
splay( l , 0 );
splay( r , l );
pushdown( root );
tag[ch[ch[root][1]][0]] ^= 1;
return;
}
write操作
作用:利用二叉搜索树的性质来输出。
void write( int cur ) {
pushdown( cur );
if( ch[cur][0] ) {
write( ch[cur][0] );
}
if( key[cur] != -inf && key[cur] != inf ) {
printf("%d ",num[key[cur]]);
}
if( key[ch[cur][1]] ) {
write( ch[cur][1] );
}
}
完整代码
题目
只有区间翻转的完整代码,注意findth函数从原本的直接返回值变成了返回下标。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define ll long long
#define N 100005
using namespace std;
int n , m;
int num[N];
int inf = 1e9;
int sz = 0 , root = 0;
int ch[N][2] , f[N] , cnt[N] , key[N] , siz[N] , tag[N] , pos[N];
inline int get( int x ) {
return ch[f[x]][1] == x;//0左1右
}
inline void update( int x ) {
if( x ) {
siz[x] = cnt[x];
if( ch[x][0] ) {
siz[x] += siz[ch[x][0]];
}
if( ch[x][1] ) {
siz[x] += siz[ch[x][1]];
}
}
return;
}
inline void pushdown( int x ) {
if( x && tag[x] ) {
tag[ch[x][0]] ^= 1;
tag[ch[x][1]] ^= 1;
swap( ch[x][0] , ch[x][1] );
tag[x] = 0;
}
}
inline void rotate( int x ) {
int old = f[x] , oldf = f[old] , which = get( x );
pushdown( old );
pushdown( x );
ch[old][which] = ch[x][which ^ 1];
f[ch[old][which]] = old;
ch[x][which ^ 1] = old;
f[old] = x;
f[x] = oldf;
if( oldf ) {
ch[oldf][ch[oldf][1] == old] = x;
}
update( old );
update( x );
return;
}
inline void splay( int x , int goal ) {
for ( int fa ; ( fa = f[x] ) != goal ; rotate(x) ) {
if( f[fa] != goal ) {
rotate( ( get( x ) == get( fa ) ? fa : x ) );
}
}
if( !goal ) {
root = x;
}
return;
}
inline int findth( int k ) {
int cur = root;
while ( true ) {
pushdown( cur );
if( ch[cur][0] && k <= siz[ch[cur][0]] ) {
cur = ch[cur][0];
} else {
int tem = ( ch[cur][0] ? siz[ch[cur][0]] : 0 ) + cnt[cur];
if( k <= tem ) {
return cur;
}
k -= tem;
cur = ch[cur][1];
}
}
}
int build( int p , int l , int r ) {
if( l > r ) {
return 0;
}
int mid = ( l + r ) >> 1;
int cur = ++sz;
key[cur] = pos[mid];
f[cur] = p;
tag[cur] = 0;
siz[cur]++;
cnt[cur]++;
ch[cur][0] = build( cur , l , mid - 1 );
ch[cur][1] = build( cur , mid + 1 , r );
update( cur );
return cur;
}
inline void turn( int l , int r ) {
l = findth( l );
r = findth( r + 2 );
splay( l , 0 );
splay( r , l );
pushdown( root );
tag[ch[ch[root][1]][0]] ^= 1;
return;
}
void write( int cur ) {
pushdown( cur );
if( ch[cur][0] ) {
write( ch[cur][0] );
}
if( key[cur] != -inf && key[cur] != inf ) {
printf("%d ",num[key[cur]]);
}
if( key[ch[cur][1]] ) {
write( ch[cur][1] );
}
}
int main() {
scanf("%d%d",&n,&m);
for ( int i = 1 ; i <= n ; ++i ) {
num[i] = i;
pos[i + 1] = i;
}
pos[1] = -inf;
pos[n + 2] = inf;
root = build( 0 , 1 , n + 2 );
for ( int i = 1 ; i <= m ; ++i ) {
int x , y;
scanf("%d%d",&x,&y);
turn( x , y );
}
write( root );
return 0;
}
进阶操作二:区间翻转和区间交换
revolve操作
作用:实现区间平移(区间
[
l
,
r
]
[l,r]
[l,r]循环位移前进
k
k
k位)。
思路:其实区间平移可以借用区间翻转实现,具体表现为令
k
%
(
r
−
l
+
1
)
k \% (r-l+1)
k%(r−l+1),则我们可以依次翻转区间
[
l
,
r
]
[l,r]
[l,r],
[
l
,
l
+
k
−
1
]
[l,l+k-1]
[l,l+k−1]和
[
l
+
k
,
r
]
[l+k,r]
[l+k,r],即可实现该操作。
inline void revolve( int x , int y , int z ) {
int len = ( y - x + 1 );
z %= len;
if( !z ) {
return;
}
int mid = x + z - 1;
turn( x , y );//turn内置了可找到对应排位的点的函数
turn( x , mid );
turn( mid + 1 , y );
return;
}
change操作
作用:实现两个区间交换位置。
思路:其实区间交换也可以借用区间翻转实现,具体表现为如果要交换区间
[
a
,
b
]
[a,b]
[a,b]和
[
c
,
d
]
[c,d]
[c,d],其中
a
≤
b
<
c
≤
d
a \le b \lt c \le d
a≤b<c≤d,我们可以依次翻转区间
[
a
,
d
]
[a,d]
[a,d],
[
a
,
a
+
d
−
c
]
[a,a+d-c]
[a,a+d−c],
[
d
−
b
+
a
,
d
]
[d-b+a,d]
[d−b+a,d]和
[
a
+
d
−
c
+
1
,
d
−
b
+
a
−
1
]
[a+d-c+1,d-b+a-1]
[a+d−c+1,d−b+a−1]即可。
inline void change( int a , int b , int c , int d ) {
if( b >= c ) {
return;
}
turn( a , d );//turn内置了可找到对应排位的点的函数
turn( a , a + d - c );
turn( d - b + a , d );
turn( a + d - c + 1 , d - b + a - 1 );
return;
}
进阶操作三:以排名为关键字时的插入删除
插入删除
插入删除略有改变,且要经常pushdown
。
inline void insert( int x , int y ) {
x = findth( x + 1 );
pushdown( x );
splay( x , 0 );
++sz;
num[sz] = y;
tag[sz] = 0;
add[sz] = 0;
cnt[sz] = 1;
key[sz] = sz;
ch[sz][1] = ch[root][1];
f[ch[root][1]] = sz;
ch[root][1] = sz;
f[sz] = root;
update( sz );
update( root );
return;
}
inline int pre() {
pushdown( root );
int cur = ch[root][0];
pushdown( cur );
while ( ch[cur][1] ) {
cur = ch[cur][1];
pushdown( cur );
}
return cur;
}
inline void del( int x ) {
x = findth( x + 1 );
pushdown( x );
splay( x , 0 );
int lpre = pre() , oldroot = root;
splay( lpre , 0 );
f[ch[oldroot][1]] = root;
ch[root][1] = ch[oldroot][1];
clear( oldroot );
update( root );
return;
}
综合了前三个进阶操作的题
题目
里面区间加,区间取最小值的操作中取区间的方式参考turn
操作。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <cstdlib>
#define ll long long
#define N 200015
using namespace std;
int n , m;
int sz = 0 , root = 0;
int inf = 1e9;
int pos[N] , a[N];
int ch[N][2] , f[N] , cnt[N] , key[N] , siz[N] , add[N] , tag[N] , num[N] , minn[N];
//num存数值,add存加数,minn存最小值,key还是存排名但是本题中可以不用了。
inline void clear( int x ) {
ch[x][0] = ch[x][1] = f[x] = cnt[x] = key[x] = siz[x] = add[x] = tag[x] = num[x] = minn[x] = 0;
}
inline int get( int x ) {
return ch[f[x]][1] == x;//0左1右
}
inline void update( int x ) {
if( x ) {
siz[x] = cnt[x];
if( ch[x][0] ) {
siz[x] += siz[ch[x][0]];
}
if( ch[x][1] ) {
siz[x] += siz[ch[x][1]];
}
minn[x] = num[x];
if( ch[x][0] ) {
minn[x] = min( minn[x] , minn[ch[x][0]] );
}
if( ch[x][1] ) {
minn[x] = min( minn[x] , minn[ch[x][1]] );
}
}
return;
}
inline void pushadd( int x , int y ) {
if( !x || !y ) {
return;
}
add[x] += y;
num[x] += y;
minn[x] += y;
return;
}
inline void pushdown( int x ) {
if( x && add[x] ) {
pushadd( ch[x][0] , add[x] );
pushadd( ch[x][1] , add[x] );
add[x] = 0;
}
if( x && tag[x] ) {
tag[ch[x][0]] ^= 1;
tag[ch[x][1]] ^= 1;
swap( ch[x][0] , ch[x][1] );
tag[x] = 0;
}
return;
}
inline void rotate( int x ) {
int old = f[x] , oldf = f[old] , which = get( x );
pushdown( old );
pushdown( x );
ch[old][which] = ch[x][which ^ 1];
f[ch[old][which]] = old;
ch[x][which ^ 1] = old;
f[old] = x;
f[x] = oldf;
if( oldf ) {
ch[oldf][ch[oldf][1] == old] = x;
}
update( old );
update( x );
return;
}
inline void splay( int x ,int goal ) {
for ( int fa ; ( fa = f[x] ) != goal ; rotate(x) ) {
if( f[fa] != goal ) {
rotate( ( get( x ) == get( fa ) ? fa : x ) );
}
}
if( !goal ) {
root = x;
}
pushdown( x );
return;
}
inline int findth( int k ) {
int cur = root;
while ( true ) {
pushdown( cur );
if( ch[cur][0] && k <= siz[ch[cur][0]] ) {
cur = ch[cur][0];
} else {
int tem = ( ch[cur][0] ? siz[ch[cur][0]] : 0 ) + cnt[cur];
if( k <= tem ) {
return cur;
}
k -= tem;
cur = ch[cur][1];
}
}
}
inline void insert( int x , int y ) {
x = findth( x + 1 );
pushdown( x );
splay( x , 0 );
++sz;
num[sz] = y;
tag[sz] = 0;
add[sz] = 0;
cnt[sz] = 1;
key[sz] = sz;
ch[sz][1] = ch[root][1];
f[ch[root][1]] = sz;
ch[root][1] = sz;
f[sz] = root;
update( sz );
update( root );
return;
}
inline int pre() {
pushdown( root );
int cur = ch[root][0];
pushdown( cur );
while ( ch[cur][1] ) {
cur = ch[cur][1];
pushdown( cur );
}
return cur;
}
inline void del( int x ) {
x = findth( x + 1 );
pushdown( x );
splay( x , 0 );
int lpre = pre() , oldroot = root;
splay( lpre , 0 );
f[ch[oldroot][1]] = root;
ch[root][1] = ch[oldroot][1];
clear( oldroot );
update( root );
return;
}
inline void turn( int l , int r ) {
if( l == r ) {
return;
}
l = findth( l );
r = findth( r + 2 );
pushdown( l );
pushdown( r );
splay( l , 0 );
splay( r , l );
pushdown( root );
tag[ch[ch[root][1]][0]] ^= 1;
return;
}
inline void getmin( int l , int r ) {
if( l == r ) {
l = findth( l + 1 );
pushdown( l );
printf("%d\n",minn[l]);
return;
}
l = findth( l );
r = findth( r + 2 );
pushdown( l );
pushdown( r );
splay( l , 0 );
splay( r , l );
printf("%d\n",minn[ch[ch[root][1]][0]]);
return;
}
inline void doadd( int x , int y , int z ) {
if( x == y ) {
x = findth( x + 1 );
pushdown( x );
num[x] += z;
return;
}
x = findth( x );
y = findth( y + 2 );
pushdown( x );
pushdown( y );
splay( x , 0 );
splay( y , x );
pushadd( ch[ch[root][1]][0] , z );
return;
}
inline void revolve( int x , int y , int z ) {
int len = ( y - x + 1 );
z %= len;
if( !z ) {
return;
}
int mid = x + z - 1;
turn( x , y );
turn( x , mid );
turn( mid + 1 , y );
return;
}
int build( int p , int l , int r ) {
if( l > r ) {
return 0;
}
int mid = ( l + r ) >> 1;
int cur = ++sz;
key[cur] = pos[mid];
if( pos[mid] != -inf && pos[mid] != inf ) {
num[cur] = minn[cur] = a[pos[mid]];
} else {
num[cur] = minn[cur] = inf;
}
f[cur] = p;
tag[cur] = 0;
add[cur] = 0;
siz[cur]++;
cnt[cur]++;
ch[cur][0] = build( cur , l , mid - 1 );
ch[cur][1] = build( cur , mid + 1 , r );
update( cur );
return cur;
}
void write( int cur ) {//测试用
pushdown( cur );
if( ch[cur][0] ) {
write( ch[cur][0] );
}
if( key[cur] != -inf && key[cur] != inf ) {
printf("%d ",num[cur]);
}
if( key[ch[cur][1]] ) {
write( ch[cur][1] );
}
}
int main() {
scanf("%d",&n);
pos[1] = -inf;
pos[n + 2] = inf;
for ( int i = 1 ; i <= n ; ++i ) {
pos[i + 1] = i;
scanf("%d",&a[i]);
}
root = build( 0 , 1 , n + 2 );
scanf("%d",&m);
int x , y , z;
char str[10];
while ( m-- ) {
scanf("%s",str);
if( str[0] == 'A' ) {
scanf("%d%d%d",&x,&y,&z);
doadd( x , y , z );
}
if( str[0] == 'I' ) {
scanf("%d%d",&x,&y);
insert( x , y );
}
if( str[0] == 'D' ) {
scanf("%d",&x);
del( x );
}
if( str[0] == 'M' ) {
scanf("%d%d",&x,&y);
getmin( x , y );
}
if( str[0] == 'R' && str[3] == 'E' ) {
scanf("%d%d",&x,&y);
turn( x , y );
}
if( str[0] == 'R' && str[3] == 'O' ) {
scanf("%d%d%d",&x,&y,&z);
revolve( x , y , z );
}
}
return 0;
}
进阶操作四:任意分裂和合并子树
splay还有一个特点,就是具有对于子树的任意分裂和合并的功能。
split操作
作用:单独拿出一个
[
x
,
x
+
l
e
n
−
1
]
[x,x+len-1]
[x,x+len−1]的区间,当作子树操作。
由于边界结点的存在,我们需要关注的区间是
[
x
+
1
,
x
+
l
e
n
]
[x+1,x+len]
[x+1,x+len]。
基本思路就是先找出排位x
的点,把它splay到根,然后再找出排位x+len+1
的点,把它splay到根(排位x
的那个点)的儿子位置(右子树上),我们不难发现此时根节点右子树的左子树上的所有点均在区间
[
x
+
1
,
x
+
l
e
n
]
[x+1,x+len]
[x+1,x+len]内,且没有遗漏的,我们返回根节点右子树的左子树的根的编号,即可实现分离出子树的操作。
inline int split( int x , int len ) {
int l = findth( x );
int r = findth( x + len + 1 );
pushdown( l );
pushdown( r );
splay( l , 0 );
splay( r , l );
pushdown( root );
return ch[ch[root][1]][0];
}
join操作
作用:合并两棵子树,我们这里是输入一连串的树,先让那一连串的数自己建树,然后再把这棵树与原来的树合并。
void build( int p , int l , int r ) {
if( l > r ) {
return;
}
int mid = ( l + r ) >> 1;
int cur = node[mid];
int fa = node[p];
num[cur] = a[mid];
f[cur] = fa;
tag[cur] = flag[cur] = 0;
cnt[cur] = 1;
siz[cur] = 1;
if( l == r ) {
sum[cur] = a[mid];
if( a[mid] >= 0 ) {
lmaxn[cur] = rmaxn[cur] = maxn[cur] = a[mid];
} else {
lmaxn[cur] = rmaxn[cur] = 0;
maxn[cur] = a[mid];
}
} else {
build( mid , l , mid - 1 );
build( mid , mid + 1 , r );
}
update( cur );
ch[fa][mid >= p] = cur;
return;
}
inline void join( int x , int len ) {
for ( int i = 1 ; i <= len ; ++i ) {
scanf("%d",&a[i]);
}
for ( int i = 1 ; i <= len ; ++i ) {
if( q.size() ) {
node[i] = q.front();
q.pop();
} else {
node[i] = ++sz;
}
}
build( 0 , 1 , len );
int c = node[( 1 + len ) >> 1];
int a = findth( x + 1 );
int b = findth( x + 2 );
splay( a , 0 );
splay( b , a );
f[c] = b;
ch[b][0] = c;
update( c );
update( b );
update( a );
return;
}
recycle操作
作用:使表示点的下标可重复利用,用队列辅助处理。
inline void recycle( int x ) {
if( !x ) {
return;
}
recycle( ch[x][0] );
recycle( ch[x][1] );
q.push(x);
f[x] = ch[x][0] = ch[x][1] = 0;
flag[x] = tag[x] = 0;
lmaxn[x] = rmaxn[x] = maxn[x] = 0;
return;
}
完整代码如下:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <queue>
#define ll long long
#define N 1000015
using namespace std;
int n , m;
int inf = 1e9;
int root , sz;
queue<int> q;
int ch[N][2] , f[N] , cnt[N] , siz[N] , tag[N] , flag[N];
int num[N] , sum[N] , maxn[N] , lmaxn[N] , rmaxn[N];
int a[N] , node[N];
inline int get( int x ) {
return ch[f[x]][1] == x;//0左1右
}
inline void update( int x ) {
sum[x] = num[x] + sum[ch[x][0]] + sum[ch[x][1]];
siz[x] = cnt[x] + siz[ch[x][0]] + siz[ch[x][1]];
maxn[x] = max( maxn[ch[x][0]] , maxn[ch[x][1]] );
maxn[x] = max( maxn[x] , rmaxn[ch[x][0]] + num[x] + lmaxn[ch[x][1]] );
lmaxn[x] = max( lmaxn[ch[x][0]] , sum[ch[x][0]] + num[x] + lmaxn[ch[x][1]] );
rmaxn[x] = max( rmaxn[ch[x][1]] , sum[ch[x][1]] + num[x] + rmaxn[ch[x][0]] );
return;
}
inline void pushdown( int x ) {
if( flag[x] ) {
tag[x] = flag[x] = 0;
if( ch[x][0] ) {
flag[ch[x][0]] = 1;
num[ch[x][0]] = num[x];
sum[ch[x][0]] = num[x] * siz[ch[x][0]];
}
if( ch[x][1] ) {
flag[ch[x][1]] = 1;
num[ch[x][1]] = num[x];
sum[ch[x][1]] = num[x] * siz[ch[x][1]];
}
if( num[x] >= 0 ) {
if( ch[x][0] ) {
lmaxn[ch[x][0]] = rmaxn[ch[x][0]] = maxn[ch[x][0]] = sum[ch[x][0]];
}
if( ch[x][1] ) {
lmaxn[ch[x][1]] = rmaxn[ch[x][1]] = maxn[ch[x][1]] = sum[ch[x][1]];
}
} else {
if( ch[x][0] ) {
lmaxn[ch[x][0]] = rmaxn[ch[x][0]] = 0;
maxn[ch[x][0]] = num[x];
}
if( ch[x][1] ) {
lmaxn[ch[x][1]] = rmaxn[ch[x][1]] = 0;
maxn[ch[x][1]] = num[x];
}
}
}
if( tag[x] ) {
tag[ch[x][0]] ^= 1;
tag[ch[x][1]] ^= 1;
tag[x] ^= 1;
swap( lmaxn[ch[x][0]] , rmaxn[ch[x][0]] );
swap( lmaxn[ch[x][1]] , rmaxn[ch[x][1]] );
swap( ch[ch[x][0]][0] , ch[ch[x][0]][1] );
swap( ch[ch[x][1]][0] , ch[ch[x][1]][1] );
}
}
inline void rotate( int x ) {
int old = f[x] , oldf = f[old] , which = get( x );
pushdown( old );
pushdown( x );
ch[old][which] = ch[x][which ^ 1];
f[ch[old][which]] = old;
ch[x][which ^ 1] = old;
f[old] = x;
f[x] = oldf;
if( oldf ) {
ch[oldf][ch[oldf][1] == old] = x;
}
update( old );
update( x );
return;
}
inline void splay( int x ,int goal ) {
for ( int fa ; ( fa = f[x] ) != goal ; rotate(x) ) {
if( f[fa] != goal ) {
rotate( ( get( x ) == get( fa ) ? fa : x ) );
}
}
if( !goal ) {
root = x;
}
pushdown( x );
return;
}
inline int findth( int k ) {
int cur = root;
while ( true ) {
pushdown( cur );
if( ch[cur][0] && k <= siz[ch[cur][0]] ) {
cur = ch[cur][0];
} else {
int tem = ( ch[cur][0] ? siz[ch[cur][0]] : 0 ) + cnt[cur];
if( k <= tem ) {
return cur;
}
k -= tem;
cur = ch[cur][1];
}
}
}
inline int split( int x , int len ) {
int l = findth( x );
int r = findth( x + len + 1 );
pushdown( l );
pushdown( r );
splay( l , 0 );
splay( r , l );
pushdown( root );
return ch[ch[root][1]][0];
}
inline void recycle( int x ) {
if( !x ) {
return;
}
recycle( ch[x][0] );
recycle( ch[x][1] );
q.push(x);
f[x] = ch[x][0] = ch[x][1] = 0;
flag[x] = tag[x] = 0;
lmaxn[x] = rmaxn[x] = maxn[x] = 0;
return;
}
inline void query( int x , int len ) {
int k = split( x , len );
printf("%d\n",sum[k]);
return;
}
inline void modify( int x , int len , int val ) {
int a = split( x , len );
int b = f[a];
num[a] = val;
flag[a] = 1;
sum[a] = siz[a] * val;
if( val >= 0 ) {
lmaxn[a] = rmaxn[a] = maxn[a] = sum[a];
} else {
lmaxn[a] = rmaxn[a] = 0;
maxn[a] = val;
}
update( b );
update( f[b] );
return;
}
inline void turn( int x , int len ) {
if( len == 1 ) {
return;
}
int a = split( x , len );
int b = f[a];
if( !flag[a] ) {
tag[a] ^= 1;
swap( ch[a][0] , ch[a][1] );
swap( lmaxn[a] , rmaxn[a] );
update( b );
update( f[b] );
}
return;
}
inline void erase( int x , int len ) {
int a = split( x , len );
int b = f[a];
recycle( a );
ch[b][0] = 0;
update( b );
update( f[b] );
return;
}
void build( int p , int l , int r ) {
if( l > r ) {
return;
}
int mid = ( l + r ) >> 1;
int cur = node[mid];
int fa = node[p];
num[cur] = a[mid];
f[cur] = fa;
tag[cur] = flag[cur] = 0;
cnt[cur] = 1;
siz[cur] = 1;
if( l == r ) {
sum[cur] = a[mid];
if( a[mid] >= 0 ) {
lmaxn[cur] = rmaxn[cur] = maxn[cur] = a[mid];
} else {
lmaxn[cur] = rmaxn[cur] = 0;
maxn[cur] = a[mid];
}
} else {
build( mid , l , mid - 1 );
build( mid , mid + 1 , r );
}
update( cur );
ch[fa][mid >= p] = cur;
return;
}
inline void join( int x , int len ) {
for ( int i = 1 ; i <= len ; ++i ) {
scanf("%d",&a[i]);
}
for ( int i = 1 ; i <= len ; ++i ) {
if( q.size() ) {
node[i] = q.front();
q.pop();
} else {
node[i] = ++sz;
}
}
build( 0 , 1 , len );
int c = node[( 1 + len ) >> 1];
int a = findth( x + 1 );
int b = findth( x + 2 );
splay( a , 0 );
splay( b , a );
f[c] = b;
ch[b][0] = c;
update( c );
update( b );
update( a );
return;
}
int main() {
scanf("%d%d",&n,&m);
a[1] = -inf;
a[n + 2] = -inf;
maxn[0] = -inf;
for ( int i = 1 ; i <= n ; ++i ) {
scanf("%d",&a[i + 1]);
}
for ( int i = 1 ; i <= n + 2 ; ++i ) {
node[i] = i;
}
build( 0 , 1 , n + 2 );
root = ( 1 + n + 2 ) >> 1;
sz = n + 2;
char opt[15];
while ( m-- ) {
scanf("%s",opt);
int x , len , val;
if( opt[0] != 'M' || opt[2] != 'X' ) {
scanf("%d%d",&x,&len);
}
if( opt[0] == 'I' ) {
join( x , len );
}
if( opt[0] == 'D' ) {
erase( x , len );
}
if( opt[0] == 'M' ) {
if( opt[2] == 'X' ) {
printf("%d\n",maxn[root]);
} else {
scanf("%d",&val);
modify( x , len , val );
}
}
if( opt[0] == 'R' ) {
turn( x , len );
}
if( opt[0] == 'G' ) {
query( x , len );
}
}
return 0;
}