树链剖分裸题
wa了一发。。。。。 是因为mark数组开小
不把数组开在一起容易开错空间啊
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define MAXN 100010
int seg[MAXN*4],sum[MAXN*4];
int fa[MAXN],size[MAXN],son[MAXN],top[MAXN],depth[MAXN],loc[MAXN];
int tot,g[MAXN*2],nnext[MAXN*2],num[MAXN*2];
int team[MAXN],head,tail;
int n,q;
void Add(int x,int y)
{
tot++;
nnext[tot]=g[x];
g[x]=tot;
num[tot]=y;
}
void Pou()
{
depth[1]=1;
team[++tail]=1;
while(head<tail)
{
int x=team[++head];
for(int i=g[x];i;i=nnext[i])
{
int tmp=num[i];
if(tmp==fa[x]) continue;
fa[tmp]=x;
team[++tail]=tmp;
}
}
for(int i=n;i>=1;i--)
{
int x=team[i];
size[x]=1;
for(int j=g[x];j;j=nnext[j])
{
int tmp=num[j];
if(tmp==fa[x]) continue ;
size[x]+=size[tmp];
if(size[tmp]>size[son[x]]) son[x]=tmp;
}
}
top[1]=1;
loc[1]=1;
for(int i=1;i<=n;i++)
{
int x=team[i];
int cnt=loc[x];
if(son[x]!=0)
{
loc[son[x]]=++cnt;
cnt+=size[son[x]]-1;
top[son[x]]=top[x];
}
for(int j=g[x];j;j=nnext[j])
{
int tmp=num[j];
if(loc[tmp]!=0) continue;
loc[tmp]=++cnt;
cnt+=size[tmp]-1;
top[tmp]=tmp;
}
}
}
int mark[MAXN*4];
void Pushdown(int now,int l,int r)
{
if(mark[now]==0||l==r) return ;
int k=mark[now]-1;mark[now]=0;
mark[now*2]=mark[now*2+1]=k+1;
int mid=(l+r)/2;
seg[now*2]=k*(mid-l+1);
seg[now*2+1]=k*(r-(mid+1)+1);
}
void Change(int now,int l,int r,int s,int t,int k)
{
if(s<=l&&r<=t)
{
seg[now]=(r-l+1)*k;
mark[now]=k+1;
return ;
}
Pushdown(now,l,r);
int mid=(l+r)/2;
if(s<=mid) Change(now*2,l,mid,s,t,k);
if(mid+1<=t) Change(now*2+1,mid+1,r,s,t,k);
seg[now]=seg[now*2]+seg[now*2+1];
}
int Q(int now,int l,int r,int s,int t)
{
if(s<=l&&r<=t)
{
return seg[now];
}
Pushdown(now,l,r);
int mid=(l+r)/2,ans=0;
if(s<=mid) ans+=Q(now*2,l,mid,s,t);
if(mid+1<=t) ans+=Q(now*2+1,mid+1,r,s,t);
return ans;
}
int Solve(int x)
{
int ans=0;
while(top[x]!=1)
{
ans+=(loc[x]-loc[top[x]]+1)-Q(1,1,n,loc[top[x]],loc[x]);
Change(1,1,n,loc[top[x]],loc[x],1);
x=fa[top[x]];
}
ans+=(loc[x]-loc[1]+1)-Q(1,1,n,loc[1],loc[x]);
Change(1,1,n,loc[1],loc[x],1);
return ans;
}
int main()
{
cin >>n;
for(int i=1;i<n;i++)
{
int x;
scanf("%d",&x);
Add(i+1,x+1);
Add(x+1,i+1);
}
Pou();
cin >>q;
while(q--)
{
char s[200];
int x;
scanf("%s %d",s+1,&x);x++;
if(s[1]=='u')
{
printf("%d\n",Q(1,1,n,loc[x],loc[x]+size[x]-1));
Change(1,1,n,loc[x],loc[x]+size[x]-1,0);
}
else
{
printf("%d\n",Solve(x));
}
}
return 0;
}