1 是什么?
左偏树是一个在堆之上的数据结构,有堆的性质,在堆合并的时间上更有优势。
对于普通的堆来说,我们合并两个堆的方式就是一个结点一个结点地插入,复杂度很高。
但是对于左偏树来说,复杂度仍然是:
O
(
l
o
g
2
n
1
+
l
o
g
2
n
2
)
O( log_2n_1+log_2n_2)
O(log2n1+log2n2)
左偏树还有一个左偏的性质,即:
我们知道,在树上的算法,深度决定了这个算法的时间性能。而这个性质可以保证我们的树深度不会太深甚至退化成链,保证了一个 n n n个结点的左偏树距离最大为: l o g 2 ( n + 1 ) − 1 log_2(n+1)-1 log2(n+1)−1
2 基本思想
这里借用一下@远航之曲 的图片。
int merge(int x,int y){
//对于为空的情况
if(!x){
return y;
}
if(!y){
return x;
}
//这里以小根堆为例
if(nodes[x].v>nodes[y].v || (nodes[x].v==nodes[y].v && x>y)){
std::swap(x,y);//保证x的值小于y的
}
nodes[x].rs=merge(nodes[x].rs,y);
nodes[nodes[x].rs].fa=x;
if(nodes[nodes[x].rs].dis>nodes[nodes[x].ls].dis){
std::swap(nodes[x].ls,nodes[x].rs);
//如果右边的距离大于左边的,就不符合左偏树的性质了,我们要交换
}
nodes[x].dis=(!nodes[x].rs)?0:nodes[nodes[x].rs].dis+1;
//更新距离
return x;
}
3其他功能
3-1 插入结点
对于插入一个结点来说,我们为了减少代码量,采用将该结点看作一个只有一个点的左偏树的方法,进行合并,调用merge函数即可,这里不再多赘述。
3-2 删除结点
我们删掉这个节点之后,再将它的左右子树合并,大功告成。
void del(int x){
int lson=nodes[x].ls;
int rson=nodes[x].rs;
nodes[x].v=-inf;//标记删除
nodes[x].fa=0;
nodes[x].ls=0;
nodes[x].rs=0;
nodes[x].dis=0;
nodes[lson].fa=lson;
nodes[rson].fa=rson;
merge(lson,rson);//合并左右结点
}
3-3 判断两个点是否在一个堆中
对于这个操作,我们使用并查集的方法。
int find(int x){
while(x!=nodes[x].fa){
x=nodes[x].fa;
}
return x;
}
4 整体代码
这里以洛谷P3377为例。
#include <cstdio>
#include <iostream>
#define maxn 100010
#define inf 0x3f3f3f3f
int n,m;
struct node{
int v,ls,rs,dis,fa;
}nodes[maxn];
int find(int x){
while(x!=nodes[x].fa){
x=nodes[x].fa;
}
return x;
}
int merge(int x,int y){
if(!x){
return y;
}
if(!y){
return x;
}
if(nodes[x].v>nodes[y].v || (nodes[x].v==nodes[y].v && x>y)){
std::swap(x,y);
}
nodes[x].rs=merge(nodes[x].rs,y);
nodes[nodes[x].rs].fa=x;
if(nodes[nodes[x].rs].dis>nodes[nodes[x].ls].dis){
std::swap(nodes[x].ls,nodes[x].rs);
}
nodes[x].dis=(!nodes[x].rs)?0:nodes[nodes[x].rs].dis+1;
return x;
}
void del(int x){
int lson=nodes[x].ls;
int rson=nodes[x].rs;
nodes[x].v=-inf;
nodes[x].fa=0;
nodes[x].ls=0;
nodes[x].rs=0;
nodes[x].dis=0;
nodes[lson].fa=lson;
nodes[rson].fa=rson;
merge(lson,rson);
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
scanf("%d",&nodes[i].v);
nodes[i].fa=i;
}
for(int i=1;i<=m;i++){
int sign;
scanf("%d",&sign);
if(sign==1){
int x,y;
scanf("%d%d",&x,&y);
int fx=find(x);
int fy=find(y);
if(fx==fy || nodes[x].v==-inf || nodes[x].v==-inf){
continue;
}
merge(fx,fy);
}
else if(sign==2){
int x;
scanf("%d",&x);
if(nodes[x].v==-inf){
printf("-1\n");
}
else{
int fx=find(x);
printf("%d\n",nodes[fx].v);
del(fx);
}
}
}
return 0;
}