这道题是叫做虚树的东西,可以发现每次的关键点个数较少,每次用O(n)的DP显然TLE。可以将所有关键点及关键点lca间连上虚边,对这O(m)个点DP,链上的点可以分成两段,分别计算贡献即可,细节蛮多,树链剖分比倍增快。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 300005
#define inf 1000000000
using namespace std;
int n,x,y,Q,m;
int first[N],next[N<<1],to[N<<1],l;
int fa[N],size[N],dep[N],ord[N],Mson[N],top[N],cnt,seq[N];
int a[N],b[N],Fa[N],X[N],Y[N],st[N],p[N],Ans[N],val[N];
bool cmp(int x,int y){return ord[x]<ord[y];}
void link(int x,int y)
{
to[++l]=y;next[l]=first[x];first[x]=l;
to[++l]=x;next[l]=first[y];first[y]=l;
}
void dfs(int x)
{
dep[x]=dep[fa[x]]+1;size[x]=1;
for (int i=first[x];i;i=next[i])
if (to[i]!=fa[x])
{
fa[to[i]]=x;dfs(to[i]);size[x]+=size[to[i]];
if (size[to[i]]>size[Mson[x]]) Mson[x]=to[i];
}
}
void dfs(int x,int y)
{
top[x]=y;ord[x]=++cnt;seq[cnt]=x;
if (Mson[x]) dfs(Mson[x],y);
for (int i=first[x];i;i=next[i])
if (to[i]!=fa[x]&&to[i]!=Mson[x]) dfs(to[i],to[i]);
}
int lca(int x,int y)
{
for (;top[x]!=top[y];x=fa[top[x]])
if (dep[top[x]]<dep[top[y]]) swap(x,y);
return dep[x]<dep[y]?x:y;
}
int Get(int x,int d)
{
for (;dep[top[x]]>d;x=fa[top[x]]);
return seq[ord[x]-dep[x]+d];
}
void Min(int x,int y,int z)
{
if(X[x]>X[y]+z||(X[x]==X[y]+z&&Y[x]>Y[y]))
X[x]=X[y]+z,Y[x]=Y[y];
}
int main()
{
scanf("%d",&n);
for (int i=1;i<n;i++)
scanf("%d%d",&x,&y),link(x,y);
dfs(1);dfs(1,1);
scanf("%d",&Q);
while(Q--)
{
int cnt=0;
scanf("%d",&m);
for (int i=1;i<=m;i++)
scanf("%d",&a[i]),p[++cnt]=b[i]=a[i],X[a[i]]=0,Y[a[i]]=a[i];
sort(a+1,a+m+1,cmp);
for (int i=1,tail=0;i<=m;i++)
if (tail)
{
int t=lca(st[tail],a[i]);
for (;dep[st[tail]]>dep[t];tail--)
if (dep[st[tail-1]]<=dep[t]) Fa[st[tail]]=t;
if (st[tail]!=t) Fa[t]=st[tail],st[++tail]=p[++cnt]=t,X[t]=inf,Y[t]=0;
st[++tail]=a[i];Fa[a[i]]=t;
}
else st[++tail]=a[i],Fa[a[i]]=0;
sort(p+1,p+cnt+1,cmp);
for (int i=cnt;i>=2;i--) Min(Fa[p[i]],p[i],dep[p[i]]-dep[Fa[p[i]]]);
for (int i=2;i<=cnt;i++) Min(p[i],Fa[p[i]],dep[p[i]]-dep[Fa[p[i]]]);
for (int i=1;i<=cnt;i++)
{
x=p[i];y=Fa[x];val[x]=size[x];
if (i==1) {Ans[Y[x]]+=n-size[x];continue;}
int t=Get(x,dep[y]+1),sum=size[t]-size[x];
val[y]-=size[t];
if (Y[x]==Y[y]) {Ans[Y[x]]+=sum;continue;}
t=X[x]-X[y]+dep[x]+dep[y]+2;
if (!(t&1)&&Y[y]>Y[x]) t-=2;t>>=1;
t=size[Get(x,t)]-size[x];
Ans[Y[x]]+=t;Ans[Y[y]]+=sum-t;
}
for (int i=1;i<=cnt;i++)
Ans[Y[p[i]]]+=val[p[i]];
for (int i=1;i<=m;i++)
printf("%d ",Ans[b[i]]),Ans[b[i]]=0;
puts("");
}
return 0;
}