正题
这题十分的简单,就是修改树上的多条路径(每个位置加一),然后输出。
动态修改就可以想到树链剖分,所以我们用树链剖分维护一下就可以了。很水不多说
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
using namespace std;
int n;
int list[300010];
struct edge{
int y,next;
}e[600010];
int len;
int first[300010];
int dep[300010],tot[300010],top[300010],image[300010],fa[300010],son[300010];
long long ans[300010];
struct tree{
int ls,rs,x,y;
long long tot;
int lazy;
}s[600010];
void build(int x,int y){
len++;
int i=len;
s[i].ls=s[i].rs=-1;
s[i].x=x;s[i].y=y;
s[i].lazy=0;s[i].tot=0;
if(x==y) return ;
int mid=(x+y)/2;
s[i].ls=len+1;build(x,mid);
s[i].rs=len+1;build(mid+1,y);
}
void ins(int x,int y){
len++;
e[len].y=y;e[len].next=first[x];first[x]=len;
}
void dfs_1(int x){
tot[x]=1;
for(int i=first[x];i!=0;i=e[i].next){
int y=e[i].y;
if(y!=fa[x]){
fa[y]=x;
dep[y]=dep[x]+1;
dfs_1(y);
if(tot[son[x]]<tot[y]) son[x]=y;
tot[x]+=tot[y];
}
}
}
void dfs_2(int x,int tp){
top[x]=tp;image[x]=++len;
if(son[x]!=0) dfs_2(son[x],tp);
for(int i=first[x];i!=0;i=e[i].next){
int y=e[i].y;
if(y!=son[x] && y!=fa[x]) dfs_2(y,y);
}
}
void pushdown(int x){
if(s[x].lazy==0) return ;
int ls=s[x].ls,rs=s[x].rs;
s[ls].lazy+=s[x].lazy;
s[rs].lazy+=s[x].lazy;
s[ls].tot+=s[x].lazy*(s[ls].y-s[ls].x+1);
s[rs].tot+=s[x].lazy*(s[rs].y-s[rs].x+1);
s[x].lazy=0;
}
void add(int now,int x,int y,int c){
if(s[now].x==x && s[now].y==y){
s[now].tot+=(s[now].y-s[now].x+1)*c;
s[now].lazy+=c;
return ;
}
pushdown(now);
int mid=s[s[now].ls].y;
if(y<=mid) add(s[now].ls,x,y,c);
else if(mid<x) add(s[now].rs,x,y,c);
else {add(s[now].ls,x,mid,c);add(s[now].rs,mid+1,y,c);}
s[now].tot=s[s[now].ls].tot+s[s[now].rs].tot;
}
long long get_sum(int now,int x,int y){
if(s[now].x==x && s[now].y==y)
return s[now].tot;
pushdown(now);
int mid=s[s[now].ls].y;
if(y<=mid) return get_sum(s[now].ls,x,y);
else if(mid<x) return get_sum(s[now].rs,x,y);
else return get_sum(s[now].ls,x,mid)+get_sum(s[now].rs,mid+1,y);
}
void solve(int x,int y){
add(1,image[x],image[x],-1);
int tx=top[x],ty=top[y];
while(tx!=ty){
if(dep[tx]>dep[ty]){
swap(tx,ty);
swap(x,y);
}
add(1,image[ty],image[y],1);
y=fa[ty];ty=top[y];
}
if(dep[x]>dep[y]) swap(x,y);
add(1,image[x],image[y],1);
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&list[i]);
for(int i=1;i<=n-1;i++){
int x,y;
scanf("%d %d",&x,&y);
ins(x,y);
ins(y,x);
}
dep[1]=1;dfs_1(1);
len=0;dfs_2(1,1);
len=0;
build(1,n);
add(1,image[list[1]],image[list[1]],1);
for(int i=2;i<=n;i++)
solve(list[i-1],list[i]);
for(int i=1;i<=n;i++)
ans[i]=get_sum(1,image[i],image[i]);
ans[list[n]]--;
for(int i=1;i<=n;i++)
printf("%lld\n",ans[i]);
}