这道题明显树剖要好想一点
所以我来说一下lct
需要你对lct理解很深,首先是将1操作转化为Access,其次是3操作的高效求解。
我们模拟一下过程,可以利用线段树维护
3操作可以维护子树信息,但结合线段树复杂度会达到logn^3,因为是一颗不改变的树,所以dfs序
有一个地方理解不深的朋友会做错,就是Access的时候对于区间端点的选择,每次要找到根节点才可以
20分钟写完代码后花费一上午检查,知道刚才才明白上行的错误,教训:时常提醒自己算法的原理,对于自己不熟练的地方一定要想清楚了再提交
#include<iostream>
#include<iomanip>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define lson (o<<1)
#define rson (o<<1)|1
using namespace std;
const int maxn=1e5+7;
int n,m,s[maxn],nxt[maxn*2],to[maxn*2],head[maxn],dep[maxn],
t[maxn][2],din[maxn],dout[maxn],dfst=0,fa[maxn][22],f[maxn],
mx[maxn*4],cnt=0,a[maxn],laz[maxn*4];
void add_edge(int x,int y)
{
nxt[++cnt]=head[x];head[x]=cnt;to[cnt]=y;
nxt[++cnt]=head[y];head[y]=cnt;to[cnt]=x;
}
void dfs(int x,int fat)
{
fa[x][0]=fat;din[x]=++dfst;dep[x]=dep[fat]+1;a[din[x]]=a[din[fat]]+1;
for(int i=1;i<=20;++i) fa[x][i]=fa[fa[x][i-1]][i-1];
for(int i=head[x];i;i=nxt[i])
{
int u=to[i];if(u==fat) continue;
f[u]=x;dfs(u,x);
}
dout[x]=dfst;
}
int find_lca(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
int l=dep[x]-dep[y];
for(int i=0;i<=20;++i) if(l&(1<<i)) x=fa[x][i];
if(x==y) return x;
for(int i=20;i>=0;--i)
if(fa[x][i]!=fa[y][i])
{x=fa[x][i];y=fa[y][i];}
return fa[x][0];
}
void up(int o)
{
mx[o]=max(mx[lson],mx[rson]);
}
void downit(int o)
{
if(!laz[o]) return ;
laz[lson]+=laz[o];laz[rson]+=laz[o];
mx[lson]+=laz[o];mx[rson]+=laz[o];laz[o]=0;
}
void build_tree(int o,int l,int r)
{
if(l==r) mx[o]=a[l];
else
{
int mid=(l+r)/2;
build_tree(lson,l,mid);build_tree(rson,mid+1,r);up(o);
}
}
void add_tree(int o,int l,int r,int ll,int rr,int v)
{
if(ll<=l&&rr>=r) {mx[o]+=v;laz[o]+=v;return ;}
int mid=(l+r)/2;downit(o);
if(ll<=mid) add_tree(lson,l,mid,ll,rr,v);
if(rr>mid) add_tree(rson,mid+1,r,ll,rr,v);
up(o);
}
int find_max(int o,int l,int r,int ll,int rr)
{
if(ll<=l&&rr>=r) return mx[o];
int mid=(l+r)/2;int maxx=-1;downit(o);
if(ll<=mid) maxx=find_max(lson,l,mid,ll,rr);
if(rr>mid) maxx=max(maxx,find_max(rson,mid+1,r,ll,rr));
return maxx;
}
int isroot(int x)
{
return t[f[x]][0]!=x&&t[f[x]][1]!=x;
}
void rotate(int x)
{
int y=f[x],z=f[y],d=t[y][1]==x;
if(!isroot(y)) t[z][t[z][1]==y]=x;
f[t[x][d^1]]=y;t[y][d]=t[x][d^1];t[x][d^1]=y;
f[y]=x;f[x]=z;
}
void splay(int x)
{
while(!isroot(x))
{
int y=f[x],z=f[y];
if(!isroot(y)) rotate(((t[z][0]==y)^(t[y][0]==x))?x:y);
rotate(x);
}
}
int solve(int x,int todep)
{
int now=0;
while(dep[x]>=todep)
{
splay(x);x=f[x];now++;
}
return now;
}
void Access(int x)
{
for(int y=0;x;y=x,x=f[x])
{
splay(x);int z=x;
while(t[z][0]) z=t[z][0];
splay(z);add_tree(1,1,n,din[z],dout[z],-1);
splay(x);
if(t[x][1])
{
z=t[x][1];while(z&&t[z][0]) z=t[z][0];
add_tree(1,1,n,din[z],dout[z],1);
}
t[x][1]=y;
}
add_tree(1,1,n,1,n,1);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;++i)
{
int x,y;scanf("%d%d",&x,&y);
add_edge(x,y);
}
dfs(1,0);cnt=0;build_tree(1,1,n);
for(int i=1;i<=m;++i)
{
int opt,x,y;scanf("%d%d",&opt,&x);
if(opt==1) Access(x);
if(opt==2)
{
scanf("%d",&y);int lca=find_lca(x,y);
int ans=solve(x,dep[lca])+solve(y,dep[lca]);
printf("%d\n",ans-1);
}
if(opt==3) printf("%d\n",find_max(1,1,n,din[x],dout[x]));
}
return 0;
}