前置知识:主席树,按秩合并并查集
首先这题不能用压缩路径并查集,我也没搞清楚,所以使用按秩合并。
按秩合并
我更喜欢叫按深度合并,就是深度小的链接到深度大的树上,至于为什么,如图,如果反过来,我们的最大深度会继续加深,若干次操作后深度会很巨大,每次遍历都要跑很多层,然而小并大就能有效减少层数
为了实现这个操作,我们可以定义寻找父节点的函数,若当前的节点父节点不为自己就递归调用,为此,我们有可以定义函数,根据编号找到某节点在树的哪个位置代码如下
int sfind(int x,int l,int r,int s)
{
if(l==r) return x;
int mid=(l+r)/2;
if(s<=mid) return sfind(tree[x].l,l,mid,s);
else return sfind(tree[x].r,mid+1,r,s);
}
int que(int x,int s)
{
int tem=sfind(x,1,m,s);
if(tree[tem].fa==s) return tem;
return que(x,tree[tem].fa);
}
我们就完成了寻找父节点的部分
合并节点部分,我们需要比较深度,然后可持久化,注意如果深度一样的话就需要再将深度再加一,同时这里如果在加一操作时不新建版本的话会跑很久(我也不清楚为什么)
代码如下
int x=que(root[i],b),y=que(root[i],c);
if(tree[x].fa!=tree[y].fa)
{
if(tree[x].dep>tree[y].dep) swap(x,y);
root[i]=com(root[i],1,m,tree[x].fa,tree[y].fa);
if(tree[x].dep==tree[y].dep)
root[i]=up(root[i],1,m,tree[y].fa);
}
int com(int x,int l,int r,int a,int b)
{
int n=t++;
tree[n]=tree[x];
if(l==r)
{
tree[n].fa=b;
return n;
}
int mid=(l+r)/2;
if(a<=mid) tree[n].l=com(tree[n].l,l,mid,a,b);
else tree[n].r=com(tree[n].r,mid+1,r,a,b);
return n;
}
int up(int x,int l,int r,int s)
{
int n=t++;
tree[n]=tree[x];
if(l==r)
{
tree[n].dep++;
return n;
}
int mid=(l+r)/2;
if(s<=mid) tree[n].l=up(tree[n].l,l,mid,s);
else tree[n].r=up(tree[n].r,mid+1,r,s);
return n;
}
最后此题就剩一些输入输出细节了
完整代码
#include <bits/stdc++.h>
using namespace std;
#define N 300000
typedef struct node
{
int l,r,fa,dep;
}node;
node tree[N*30];
int root[N*30];
int t=0,m,q;
int build(int l,int r)//建树
{
int n=t++;
if(l==r)
{
tree[n].fa=l;
tree[n].dep=0;
return n;
}
int mid=(l+r)/2;
tree[n].l=build(l,mid);
tree[n].r=build(mid+1,r);
return n;
}
int sfind(int x,int l,int r,int s)//节点标号,范围,编号
{
if(l==r) return x;
int mid=(l+r)/2;
if(s<=mid) return sfind(tree[x].l,l,mid,s);
else return sfind(tree[x].r,mid+1,r,s);
}
int que(int x,int s)//节点标号,编号//递归寻找父节点
{
int tem=sfind(x,1,m,s);
if(tree[tem].fa==s) return tem;
return que(x,tree[tem].fa);
}
int com(int x,int l,int r,int a,int b)//节点标号,范围,两个祖先//建新版本
{
int n=t++;
tree[n]=tree[x];
if(l==r)
{
tree[n].fa=b;
return n;
}
int mid=(l+r)/2;
if(a<=mid) tree[n].l=com(tree[n].l,l,mid,a,b);
else tree[n].r=com(tree[n].r,mid+1,r,a,b);
return n;
}
int up(int x,int l,int r,int s)//节点标号,范围,编号//建新版本并加深度
{
int n=t++;
tree[n]=tree[x];
if(l==r)
{
tree[n].dep++;
return n;
}
int mid=(l+r)/2;
if(s<=mid) tree[n].l=up(tree[n].l,l,mid,s);
else tree[n].r=up(tree[n].r,mid+1,r,s);
return n;
}
int main()
{
scanf("%d%d",&m,&q);
root[0]=build(1,m);
for(int i=1;i<=q;i++)
{
int a,b,c;
scanf("%d",&a);
if(a==1)
{
scanf("%d%d",&b,&c);
root[i]=root[i-1];
int x=que(root[i],b),y=que(root[i],c);
if(tree[x].fa!=tree[y].fa)
{
if(tree[x].dep>tree[y].dep) swap(x,y);//深度小并到深度大
root[i]=com(root[i],1,m,tree[x].fa,tree[y].fa);
if(tree[x].dep==tree[y].dep)
root[i]=up(root[i],1,m,tree[y].fa);
}
}
else if(a==2)
{
scanf("%d",&b);
root[i]=root[b];
}
else if(a==3)
{
scanf("%d%d",&b,&c);
root[i]=root[i-1];
int x=que(root[i],b),y=que(root[i],c);
if(tree[x].fa==tree[y].fa) printf("1\n");
else printf("0\n");
}
}
return 0;
}