题意
给定一棵n个点的有根树,编号依次为1到n,其中1号点是根节点。每个节点都被染上了某一种颜色,其中第i个节点的颜色为c[i]。如果c[i]=c[j],那么我们认为点i和点j拥有相同的颜色。定义depth[i]为i节点与根节点的距离,为了方便起见,你可以认为树上相邻的两个点之间的距离为1。站在这棵色彩斑斓的树前面,你将面临m个问题。每个问题包含两个整数x和d,表示询问x子树里且depth不超过depth[x]+d的所有点中出现了多少种本质不同的颜色。请写一个程序,快速回答这些询问。
第一行包含一个正整数T(1<=T<=500),表示测试数据的组数。
每组数据中,第一行包含两个正整数n(1<=n<=100000)和m(1<=m<=100000),表示节点数和询问数。
第二行包含n个正整数,其中第i个数为ci,分别表示每个节点的颜色。
第三行包含n-1个正整数,其中第i个数为fi+1,表示节点i+1的父亲节点的编号。
接下来m行,每行两个整数x(1<=x<=n)和d(0<=d
分析
之前做过一道类似的题是询问节点权值和,用的是KDtree来做,所以感觉思维被限制住了。
又是比较牛逼的一道题。
考虑如果没有深度限制的话,每个节点的贡献是1,两个相同颜色的节点贡献会算重,所以要在lca处-1。
推广到多个点,就是把相同颜色的点按dfs序排序后,相邻两个点的lca处贡献要-1。
现在多了深度的限制,我们可以按深度建可持久化线段树,来维护所有深度不大于某个值的节点贡献。
然后用n个set来维护每种颜色的dfs序,插入的时候在里面乱搞一下即可。
bin数组开小了1害得我多调了1h。。。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<set>
using namespace std;
const int N=100005;
int n,m,dfn[N],mx[N],pos[N],tot,sz,tim,cnt,last[N],col[N],dep[N],fa[N],rmq[N*2][20],bin[20],root[N],lg[N*2],a[N];
set<pair<int,int> > se[N];
struct tree{int l,r,s;}t[N*80];
struct edge{int to,next;}e[N];
int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
bool cmp_dep(int x,int y)
{
return dep[x]<dep[y];
}
void addedge(int u,int v)
{
e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
}
void dfs(int x)
{
dfn[x]=++tim;rmq[++tot][0]=x;pos[x]=tot;dep[x]=dep[fa[x]]+1;
for (int i=last[x];i;i=e[i].next) dfs(e[i].to),rmq[++tot][0]=x;
mx[x]=tim;
}
void pre_rmq()
{
bin[0]=1;
for (int i=1;i<=18;i++) bin[i]=bin[i-1]*2;
for (int i=1;i<=tot;i++) lg[i]=log(i)/log(2);
for (int j=1;j<=lg[tot];j++)
for (int i=1;i+bin[j]-1<=tot;i++)
rmq[i][j]=dep[rmq[i][j-1]]<dep[rmq[i+bin[j-1]][j-1]]?rmq[i][j-1]:rmq[i+bin[j-1]][j-1];
}
int get_lca(int x,int y)
{
x=pos[x];y=pos[y];
if (x>y) swap(x,y);
int w=lg[y-x+1];
return dep[rmq[x][w]]<dep[rmq[y-bin[w]+1][w]]?rmq[x][w]:rmq[y-bin[w]+1][w];
}
void ins(int &d,int p,int l,int r,int x,int y)
{
d=++sz;t[d]=t[p];t[d].s+=y;
if (l==r) return;
int mid=(l+r)/2;
if (x<=mid) ins(t[d].l,t[p].l,l,mid,x,y);
else ins(t[d].r,t[p].r,mid+1,r,x,y);
}
int query(int d,int l,int r,int x,int y)
{
if (x>y) return 0;
if (l==x&&r==y) return t[d].s;
int mid=(l+r)/2;
return query(t[d].l,l,mid,x,min(y,mid))+query(t[d].r,mid+1,r,max(x,mid+1),y);
}
int main()
{
int T=read();
while (T--)
{
n=read();m=read();cnt=tot=sz=tim=0;
for (int i=1;i<=n;i++) col[i]=read(),last[i]=root[i]=0,se[i].clear();
for (int i=2;i<=n;i++) fa[i]=read(),addedge(fa[i],i);
dfs(1);pre_rmq();
for (int i=1;i<=n;i++) a[i]=i;
sort(a+1,a+n+1,cmp_dep);
for (int i=1;i<=n;i++)
{
int x=a[i],c=col[x],d=dep[x],p=0,q=0;
if (d!=dep[a[i-1]]) root[d]=root[d-1];
set<pair<int,int> >::iterator it=se[c].lower_bound(make_pair(dfn[x],x));
if (it!=se[c].end()) q=(*it).second;
if (it!=se[c].begin()) it--,p=(*it).second;
se[c].insert(make_pair(dfn[x],x));
if (p) ins(root[d],root[d],1,n,dfn[get_lca(x,p)],-1);
if (q) ins(root[d],root[d],1,n,dfn[get_lca(x,q)],-1);
if (p&&q) ins(root[d],root[d],1,n,dfn[get_lca(p,q)],1);
ins(root[d],root[d],1,n,dfn[x],1);
}
for (int i=1;i<=n;i++) if (!root[i]) root[i]=root[i-1];
int ans=0;
while (m--)
{
int x=read()^ans,d=read()^ans;
printf("%d\n",ans=query(root[min(n,dep[x]+d)],1,n,dfn[x],mx[x]));
}
}
return 0;
}