这道题还是蛮有趣的 把每一个询问的链拆成s->lca 和lca->t 然后tarjan或者倍增求lca都可过(我的tarjan写丑了) 具体写法见大佬博客
我就放份代码 在联系新的代码风格
#include<bits/stdc++.h>
using namespace std;
const int N=3e5+5;
int n,m,tot=1,cnt=1;
int a[N],head[N],f[N],dep[N],ans[N],val[N],head2[N],tong[N],num[N*2];
struct Egde{
int id,v,nxt;
}E[N*2],e[N*2];
struct node{
int s,t,lca,len;
}b[N];
bool vis[N];
vector<int>tmp[N],tmp2[N],tmp3[N];
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
void add(int u,int v){
e[++tot].v=v,e[tot].nxt=head[u],head[u]=tot;
e[++tot].v=u,e[tot].nxt=head[v],head[v]=tot;
}
void Add(int u,int v,int id){
E[++cnt].v=v,E[cnt].id=id,E[cnt].nxt=head2[u],head2[u]=cnt;
E[++cnt].v=u,E[cnt].id=id,E[cnt].nxt=head2[v],head2[v]=cnt;
}
int find(int x){
return f[x]=f[x]==x?x:find(f[x]);
}
void tarjan(int x,int fa){
f[x]=x;vis[x]=1;dep[x]=dep[fa]+1;
for(int i=head2[x];i;i=E[i].nxt){
int op=E[i].id;
if(vis[b[op].s]&&b[op].t==x) b[op].lca=find(b[op].s);
if(vis[b[op].t]&&b[op].s==x) b[op].lca=find(b[op].t);
}
for(int i=head[x];i;i=e[i].nxt){
int j=e[i].v;
if(j==fa) continue;
tarjan(j,x);
f[j]=x;
}
}
void dfs(int x,int fa){
int now1=dep[x]+a[x],now2=3e5+dep[x]-a[x];
int t1=tong[now1],t2=num[now2];
for(int i=head[x];i;i=e[i].nxt){
int j=e[i].v;
if(j==fa) continue;
dfs(j,x);
}
tong[dep[x]]+=val[x];
ans[x]+=tong[now1]-t1;
for(int i=0,s=tmp[x].size();i<s;i++) tong[dep[tmp[x][i]]]--;
for(int i=0,s=tmp2[x].size();i<s;i++) num[tmp2[x][i]]++;
ans[x]+=num[now2]-t2;
for(int i=0,s=tmp3[x].size();i<s;i++) num[tmp3[x][i]]--;
}
int main(){
n=read(),m=read();
for(int i=1,x,y;i<n;i++){
x=read(),y=read(),add(x,y);
}
for(int i=1;i<=n;i++) a[i]=read();
for(int i=1;i<=m;i++) b[i].s=read(),b[i].t=read(),val[b[i].s]++,Add(b[i].s,b[i].t,i);
tarjan(1,0);
for(int i=1;i<=m;i++){
b[i].len=dep[b[i].s]+dep[b[i].t]-2*dep[b[i].lca];
tmp[b[i].lca].push_back(b[i].s);
tmp2[b[i].t].push_back(dep[b[i].t]-b[i].len+3e5);
tmp3[b[i].lca].push_back(dep[b[i].t]-b[i].len+3e5);
}
dfs(1,0);
for(int i=1;i<=m;i++){
if(dep[b[i].s]-dep[b[i].lca]==a[b[i].lca])
ans[b[i].lca]--;
}
for(int i=1;i<n;i++) cout<<ans[i]<<" ";cout<<ans[n];
return 0;
}