/*
http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemId=4969
题意清晰。
思路:
- DFS 后序遍历:给树上节点做线性编号,保证子孙的编号都比祖先小;同时求出每个节点子树对应的线性区间 [fa, nid](fa 为子树中的最小编号,nid 为该节点自身的编号)。
- 常规线段树区间操作(区间翻转、区间求和),注意延迟更新(lazy 标记)!
- 注意每组样例输出结束后,再输出一个空行。
*/
#include<stdio.h>
#include<iostream>
#include<string.h>
#include<algorithm>
#include<map>
using namespace std;
#define M 100005

// Per-tree-node data, indexed by the original node number (1..n).
struct node
{
    int oid;  // original node number (written by main, not read elsewhere in this file)
    int nid;  // post-order number assigned by dfs (1..n)
    int fa;   // first holds the parent read from input; dfs then overwrites it
              // with the smallest post-order number found in the node's subtree
};
node no[M];

// Child lists stored as singly linked edge chains rooted at head[].
struct node1
{
    int from;  // index of the next edge leaving the same parent, or -1 at the end
    int to;    // child node number
};
node1 edge[M];

// Segment-tree node covering the half-open position range [l, r).
struct node2
{
    int cover;  // pending flip flag that has not yet been pushed to the children
    int sum;    // number of set positions inside [l, r)
    int l;
    int r;
};
node2 stree[M*4];

int head[M];  // first edge index for each parent, -1 when none (reset per test case in main)
int tot;      // number of edges stored so far
int now;      // last post-order number handed out by dfs
void addedge(int sup,int index)
{
edge[tot].to=index;
edge[tot].from=head[sup];
head[sup]=tot++;
}
void dfs(int st)
{
no[st].fa=M;
for(int i=head[st];i!=-1;i=edge[i].from)
{
int v=edge[i].to;
dfs(v);
no[st].fa=min(no[st].fa,no[v].fa);
}
no[st].nid=++now;
no[st].fa=min(no[st].fa,no[st].nid);
}
// Segment tree over half-open position ranges [l, r), supporting range flip and range sum.
// Node id has children id*2 (left half) and id*2+1 (right half); a node with
// r - l <= 1 is a leaf and its children are never touched.
// stree[id].cover != 0 means "the children still need to be flipped" (lazy tag).

// Flips every position covered by node `id`: the count of set positions becomes
// (length - count). The flip is also recorded as pending for the node's children.
static void flip_node(int id)
{
    stree[id].cover ^= 1;
    stree[id].sum = (stree[id].r - stree[id].l) - stree[id].sum;
}

// Moves a pending flip from internal node `id` onto its two children.
static void push_down(int id)
{
    if (stree[id].cover != 0)
    {
        flip_node(id * 2);
        flip_node(id * 2 + 1);
        stree[id].cover = 0;
    }
}

// Builds the subtree of node `id` over [l, r) with every position cleared.
void init(int l,int r,int id)
{
    stree[id].l = l;
    stree[id].r = r;
    stree[id].cover = 0;
    stree[id].sum = 0;
    if (r - l <= 1)
        return;  // leaf
    int mid = (l + r) / 2;
    init(l, mid, id * 2);
    init(mid, r, id * 2 + 1);
}

// Flips the positions in [l, r) `key` times. [l, r) must lie inside node id's range.
// Flips toggle, so only the parity of `key` matters; an even key changes nothing.
// The original applied the flip to `sum` even for even keys while leaving `cover` unchanged.
void update(int l,int r,int key,int id)
{
    if (key % 2 == 0)
        return;
    if (stree[id].l == l && stree[id].r == r)
    {
        flip_node(id);
        return;
    }
    push_down(id);
    int mid = (stree[id].l + stree[id].r) / 2;
    if (r <= mid)
        update(l, r, key, id * 2);
    else if (l >= mid)
        update(l, r, key, id * 2 + 1);
    else
    {
        update(l, mid, key, id * 2);
        update(mid, r, key, id * 2 + 1);
    }
    stree[id].sum = stree[id * 2].sum + stree[id * 2 + 1].sum;
}

// Returns how many positions in [l, r) are set. [l, r) must lie inside node id's range.
int search(int l,int r,int id)
{
    if (stree[id].l == l && stree[id].r == r)
        return stree[id].sum;
    push_down(id);
    int mid = (stree[id].l + stree[id].r) / 2;
    if (r <= mid)
        return search(l, r, id * 2);
    if (l >= mid)
        return search(l, r, id * 2 + 1);
    return search(l, mid, id * 2) + search(mid, r, id * 2 + 1);
}
int main()
{
int n,i,m,k,j;
char c[3];
while(scanf("%d%d",&n,&m)!=EOF)
{
tot=now=0;
memset(head,-1,sizeof(head));
for(i=2;i<=n;i++)
{
scanf("%d",&no[i].fa);
no[i].oid=i;
addedge(no[i].fa,i);
}
dfs(1);
init(0,n,1);
while(m--)
{
scanf("%s%d",c,&k);
if(c[0]=='o')
{
update(no[k].fa-1,no[k].nid,1,1);
}
else
{
//puts("ss");
k=search(no[k].fa-1,no[k].nid,1);
printf("%d\n",k);
}
}
puts("");
}
}