题目传送门
一道大水题。
解法:
简单的树剖。
到这个点这个点就要加1。
线段树维护。
这题有一点点坑就是除了第一个点其他最后到的点都要少一颗糖果。
因为他去吃大餐了。
代码实现:
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<queue>
using namespace std;
struct node {int x,y,next;}a[1100000];int len,last[310000];
void ins(int x,int y) {len++;a[len].x=x;a[len].y=y;a[len].next=last[x];last[x]=len;}
int n,fa[310000],dep[310000],tot[310000],son[310000];
void pre_tree_node(int x) {
tot[x]=1;son[x]=0;
for(int k=last[x];k;k=a[k].next) {
int y=a[k].y;
if(y!=fa[x]) {
fa[y]=x;dep[y]=dep[x]+1;
pre_tree_node(y);if(tot[son[x]]<tot[y])son[x]=y;tot[x]+=tot[y];
}
}
}
int z,ys[310000],top[310000];
void pre_tree_edge(int x,int tp) {
top[x]=tp;ys[x]=++z;
if(son[x]!=0)pre_tree_edge(son[x],tp);
for(int k=last[x];k;k=a[k].next) {
int y=a[k].y;
if(y!=fa[x]&&y!=son[x])pre_tree_edge(y,y);
}
}
struct trnode {int l,r,lc,rc,c,lazy;}tr[1100000];int trlen;
void bt(int l,int r) {
int now=++trlen;
tr[now].l=l;tr[now].r=r;tr[now].lc=tr[now].rc=-1;tr[now].c=0;tr[now].lazy=0;
if(l<r) {
int mid=(l+r)/2;
tr[now].lc=trlen+1;bt(l,mid);tr[now].rc=trlen+1;bt(mid+1,r);
}
}
void update(int now) {
int lc=tr[now].lc,rc=tr[now].rc;
tr[lc].c+=tr[now].lazy;tr[rc].c+=tr[now].lazy;
tr[lc].lazy+=tr[now].lazy;tr[rc].lazy+=tr[now].lazy;tr[now].lazy=0;
}
void change(int now,int l,int r) {
if(tr[now].l==l&&tr[now].r==r) {
tr[now].lazy+=1;tr[now].c+=1;return ;
}
if(tr[now].lazy!=0)update(now);
int lc=tr[now].lc,rc=tr[now].rc,mid=(tr[now].l+tr[now].r)/2;
if(r<=mid)change(lc,l,r);
else if(l>mid)change(rc,l,r);
else {change(lc,l,mid),change(rc,mid+1,r);}
}
int findsum(int now,int x) {
if(tr[now].l==tr[now].r)return tr[now].c;
if(tr[now].lazy!=0)update(now);
int lc=tr[now].lc,rc=tr[now].rc,mid=(tr[now].l+tr[now].r)/2;
if(x<=mid)return findsum(lc,x);
else return findsum(rc,x);
}
void solve(int x,int y) {
int tx=top[x],ty=top[y];
while(tx!=ty) {
if(dep[tx]>dep[ty]) {swap(tx,ty);swap(x,y);}
change(1,ys[ty],ys[y]);y=fa[ty];ty=top[y];
}
if(dep[x]>dep[y])swap(x,y);change(1,ys[x],ys[y]);
}
int s[310000];
int main() {
int n;scanf("%d",&n);len=0;memset(last,0,sizeof(last));
for(int i=1;i<=n;i++)scanf("%d",&s[i]);
for(int i=1;i<n;i++) {int x,y;scanf("%d%d",&x,&y);ins(x,y);ins(y,x);}
fa[1]=0;dep[1]=0;pre_tree_node(1);
z=0;pre_tree_edge(1,1);
trlen=0;bt(1,z);
for(int i=2;i<=n;i++)solve(s[i],s[i-1]);
for(int i=1;i<=n;i++){int ans=findsum(1,ys[i]);if(s[1]!=i)ans--;printf("%d\n",ans);}
return 0;
}