产生dfs序列之后,记录每个点对于的L[i]和R[i],每个点对应的子树就确定了。设计一种处理顺序,使得每个点的不被多次重复计算。
对L[i]进行分块,将问题分类,设每块大小为S。同时处理Id=L[i]/S相同的节点:
(1)如果R[i]/S=id,则将该询问放到ask1中
(2)如果R[i]/S!=id,则将该询问放到ask2中。
对于ask1,直接暴力,处理每个问题。复杂度O(√N*√N)=O(N)
对于ask2,先按左端点排序,排序之后,每个节点对应的区间不会相交,应该dfs序列是合法的括号序列,区间一定满足[ [ [ ] ] ]
先处理最外面的节点,再依次删除,逐步回答里面的节点。复杂度(√N*log√N+N)=O(N)
块的个数√N个,所有总的复杂度O(N√N),具体实现键代码:
#include<stdio.h>
#include<string.h>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cmath>
#include<map>
#define M 100005
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
struct node{
int x,y,id;
};
map<int,int>se;
vector<int>edge[M];
vector<node>bucket[M];
vector<node>ask1;
vector<node>ask2;
int L[M],R[M],col[M],T,ds[M],SZ,K;
int ans[M],cnt[M],sum[M];
bool cmp(node A,node B){
return A.x<B.x;
}
void dfs(int x,int f){
ds[++T]=x; L[x]=T;
for(int i=0;i<edge[x].size();i++){
int y=edge[x][i];
if(y==f)continue;
dfs(y,x);
}
R[x]=T;
}
void Add(int x){
cnt[sum[x]]--;sum[x]++;cnt[sum[x]]++;
}
void Del(int x){
cnt[sum[x]]--;sum[x]--;cnt[sum[x]]++;
}
void solve1(){
memset(cnt,0,sizeof(cnt));
memset(sum,0,sizeof(sum));
int i,j,k;
for(i=0;i<ask1.size();i++){
for(j=ask1[i].x;j<=ask1[i].y;j++){
Add(col[ds[j]]);
}
ans[ask1[i].id]=cnt[K];
for(j=ask1[i].x;j<=ask1[i].y;j++){
Del(col[ds[j]]);
}
}
}
void solve2(){
sort(ask2.begin(),ask2.end(),cmp);
memset(cnt,0,sizeof(cnt));
memset(sum,0,sizeof(sum));
int i,j,k;
for(i=ask2[0].x;i<=ask2[0].y;i++)
Add(col[ds[i]]);
ans[ask2[0].id]=cnt[K];
int L=ask2[0].x,R=ask2[0].y;
for(i=1;i<ask2.size();i++){
int x=ask2[i].x;
int y=ask2[i].y;
while(L<x)Del(col[ds[L++]]);
while(R>y)Del(col[ds[R--]]);
ans[ask2[i].id]=cnt[K];
}
}
int main(){
int n,m,i,j,a,b,cas,v=1;
scanf("%d",&cas);
while(cas--&&scanf("%d %d",&n,&K)){
se.clear();
int ID=1;
for(i=1;i<=n;i++){
edge[i].clear();
scanf("%d",&col[i]);
if(se[col[i]]==0)se[col[i]]=col[i]=ID++;
else col[i]=se[col[i]];
}
for(i=1;i<n;i++){
scanf("%d %d",&a,&b);
edge[a].push_back(b);
edge[b].push_back(a);
}
T=-1;
dfs(1,0);
SZ=int(sqrt(n*1.0));
int p=n/SZ;
for(i=p;i>=0;i--)
bucket[i].clear();
node tmp;
for(i=1;i<=n;i++){
tmp.x=L[i],tmp.y=R[i],tmp.id=i;
bucket[L[i]/SZ].push_back(tmp);
}
for(i=0;i<=p;i++){
ask1.clear();
ask2.clear();
for(j=0;j<bucket[i].size();j++){
if(bucket[i][j].y/SZ==i){
ask1.push_back(bucket[i][j]);
}else{
ask2.push_back(bucket[i][j]);
}
}
if(ask1.size()>0) solve1();
if(ask2.size()>0) solve2();
}
if(v>1)puts("");
printf("Case #%d:\n",v++);
scanf("%d",&m);
for(i=0;i<m;i++){
scanf("%d",&a);
printf("%d\n",ans[a]);
}
}
return 0;
}