题意
一棵 n 个节点的树,有两种共
q 个操作:
①.第 x 个操作:将某一个点赋值为x
②.询问 u 到v 的路径上有多少个点,并求出路径上已赋值的点中,值小于等于 W 的点数
解法
树链剖分+主席树:
这道题的关键是如何计数。方法之一是离线做,网上大部分的做法都是如此,就不再赘述。另一种方法就是在线做,考虑用树状数组套用主席树然后再用树链剖分求得答案
我们以dfs 序开一个树状数组,树状数组的每一个节点表示一棵值域线段树(值域为 【1,q】 ),可以称之为可持久化树状数组(具体可以见代码)
对于操作①,我们可以直接用类似于树状数组的单点修改的方法,更改一系列的节点,共 logn 个,每一个节点都是一棵树,有 logn 层,所有更新一次需要 log2n 的复杂度
对于操作②,答案1很容易得出,为 depu+depv−2∗depLCAu,v+1 ,直接用树剖求 LCA 比倍增块,然后就是求答案2:
采用树剖和树状数组区间求和的方法。我们利用树剖跳链求出一段一段的区间,然后在树状数组上求出各段的贡献,即树状数组 【l,r】 内的点所代表的线段树的 【1,W】 的和
复杂度
O( nlog2n ),常数有点大……
代码
#include<iostream>
#include<cstdlib>
#include<cstdio>
#define Rint register int
using namespace std;
const int MAXN=200010;
struct node
{
int next,to;
}e[MAXN];
int head[MAXN],num;
int ls[MAXN*100],rs[MAXN*100];
int dfn[MAXN],top[MAXN];
int dep[MAXN],son[MAXN];
int siz[MAXN],rt[MAXN];
int sum[MAXN*100];
int f[MAXN];
int n,q,cnt,W;
int srt,tot;
int lowbit(int x)
{
return x&(-x);
}
void add(int u,int v)
{
e[++num]=(node){ head[u],v };
head[u]=num;
}
void Tarjan(Rint k)
{
siz[k]++;
for(int i=head[k],x; i ;i=e[i].next)
{
x=e[i].to;
dep[x]=dep[k]+1,f[x]=k;
Tarjan( x ),siz[k]+=siz[x];
if( siz[x]>siz[son[k]] ) son[k]=x;
}
}
void dfs(Rint k)
{
dfn[k]=++cnt;
if( son[k] ) top[son[k]]=top[k],dfs( son[k] );
for(int i=head[k],x; i ;i=e[i].next)
{
x=e[i].to;
if( x==son[k] ) continue ;
top[x]=x,dfs( x );
}
}
void build(int &k,int l,int r)
{
if( !k ) k=++tot;
if( l==r ) return ;
int mid=(l+r)/2;
build( ls[k],l,mid ),build( rs[k],mid+1,r );
}
void insert(Rint &k,Rint lrt,Rint l,Rint r,Rint x)
{
if( !k ) k=++tot;
if( l==r ) { sum[k]=1;return ; }
int mid=(l+r)/2;
if( x<=mid ) rs[k]=rs[lrt],insert( ls[k],ls[lrt],l,mid,x );
else ls[k]=ls[lrt],insert( rs[k],rs[lrt],mid+1,r,x );
sum[k]=sum[ls[k]]+sum[rs[k]];
}
void Add(int k,int x)
{
for( ; k<=n ;k+=lowbit( k )) insert( rt[k],rt[k],1,q,x );
}
int query(Rint k,Rint l,Rint r,Rint L,Rint R)
{
if( l==L && r==R ) return sum[k];
int mid=(l+r)/2;
if( R<=mid ) return query( ls[k],l,mid,L,R );
else
if( L>=mid+1 ) return query( rs[k],mid+1,r,L,R );
else return query( ls[k],l,mid,L,mid )+query( rs[k],mid+1,r,mid+1,R );
}
int getsum(int l,int r)
{
int ret=0;
l--;
while( r ) ret+=query( rt[r],1,q,1,W ),r-=lowbit( r );
while( l ) ret-=query( rt[l],1,q,1,W ),l-=lowbit( l );
return ret;
}
int work(int u,int v)
{
int ret=0;
if( dep[u]<dep[v] ) swap( u,v );
while( top[u]!=top[v] )
{
if( dep[top[u]]<dep[top[v]] ) swap( u,v );
ret+=getsum( dfn[top[u]],dfn[u] );
u=f[top[u]];
}
if( dep[u]>dep[v] ) swap( u,v );
ret+=getsum( dfn[u],dfn[v] );
return ret;
}
int LCA(int u,int v)
{
while( top[u]!=top[v] )
{
if( dep[top[u]]<dep[top[v]] ) swap( u,v );
u=f[top[u]];
}
return dep[u]<dep[v] ? u : v ;
}
int main()
{
int opt,u,v,c;
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%d",&u);
if( u ) add( u,i );
else srt=i;
}
dep[1]=1,Tarjan( srt );
top[1]=1,dfs( srt );
scanf("%d",&q);
build( rt[0],1,q );
for(int i=1;i<=q;i++)
{
scanf("%d",&opt);
if( opt==1 )
{
scanf("%d%d%d",&u,&v,&c),W=i-c-1;
printf("%d %d\n",dep[u]+dep[v]-2*dep[LCA( u,v )]+1, W>=1 ? work( u,v ) : 0 );
}
else scanf("%d",&u),Add( dfn[u],i );
}
return 0;
}