按照思路来写就行了
#include<bits/stdc++.h>
using namespace std;
struct node
{
int to,nxt;
}sq[200100];
int n,m,dep[100100],fa[100100],all=0,head[100100];
bool vis[100100];
int read()
{
int x=0,f=1;
char ch=getchar();
while((ch<'0')||(ch>'9'))
{
if(ch=='-')
f=-1;
ch=getchar();
}
while((ch>='0')&&(ch<='9'))
{
x=x*10+(ch-'0');
ch=getchar();
}
return x*f;
}
void add(int u,int v)
{
all++;
sq[all].to=v;
sq[all].nxt=head[u];
head[u]=all;
}
void dfs(int u,int fu)
{
dep[u]=dep[fu]+1;
fa[u]=fu;
int i;
for(i=head[u];i;i=sq[i].nxt)
if(sq[i].to!=fu)
dfs(sq[i].to,u);
}
int main()
{
n=read();
m=read();
int i;
for(i=1;i<n;i++)
{
int u=read(),v=read();
add(u,v);
add(v,u);
}
dep[0]=-1;
int ans=0;
dfs(1,0);
memset(vis,0,sizeof(vis));
vis[1]=1;
for(i=1;i<=m;i++)
{
int x=read(),j;
for(j=x;!vis[j];j=fa[j])
{
vis[j]=1;
ans+=2;
}
printf("%d\n",ans-dep[x]);
}
return 0;
}
来源:zr