一看数据范围便能猜到要么是O(n),要么是n*log(n)的算法,并且这个查询和更改操作使人很自然地想起了树状数组,但是树状数组只能对区间进行操作,而题目的数据给出的是树的形式,需要将根节点和它的子树这个范围变成[ , ]区间的形式。我们可以从根节点出发用时间戳st[]记录开始位置,用end[]
记录结束位置,dfs根节点,求出每个点的区间,然后就变成了对区间的操作,查询区间操作依然是用r-(l-1) ;因为树状数组是从1开始记录的!!!
#include<iostream>
#include<cstdio>
using namespace std;
struct node{
int from,to;
}list[100005];
int head[100005],t[100005],st[100005],end[100005],tot,n,q,x,y,s,m;
char ch;
bool v[100005];
void add(int x,int y){
list[++s].from=head[x];
list[s].to=y;
head[x]=s;
}
int lowbit(int x){
return x&-x;
}
void dfs(int x){
st[x]=++tot;
for (int i=head[x];i;i=list[i].from)
{
if (!v[list[i].to])
dfs(list[i].to),v[list[i].to]=1;
}
end[x]=tot;
}
void up(int x,int y){
while(x<=n){
t[x]+=y;
x+=lowbit(x);
}
}
int query(int x){
int sum=0;
while(x){
sum+=t[x];
x-=lowbit(x);
}
return sum;
}
int main(){
cin>>n;
for (int i=1;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y);
}
dfs(1);
for (int i=1;i<=n;i++)
up(i,1);
cin>>m;
for (int i=1;i<=m;i++){
getchar();
scanf("%c%d",&ch,&q);
if(ch=='Q'){
cout<<(query(end[q])-query(st[q]-1))<<endl;
}
else {
if (query(st[q])-query(st[q]-1)==1) up(st[q],-1);
else up(st[q],1);
}
}
return 0;
}