题目描述:
题目分析:
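算中点的时候注意一下中间点的划分：虚树边两端的最近关键点不同时，要在这条链上找到分界点；若链上某点到两个关键点的距离恰好相等，它应归编号较小的关键点。
不要漏了非虚树边的贡献：虚树点子树中不通向任何虚树儿子的那部分原树结点，以及虚树根子树之外的结点，都要计入对应的最近关键点。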
Code:
#include<bits/stdc++.h>
#define maxn 300005
using namespace std;
// Tree size, number of queries, and size of the current query.
int n,m,k;
// a[] keeps the current query's key points in input order; b[] is the working
// copy handed to VTree::solve (which reorders and extends it); ans[x] is the
// value accumulated for key point x during the current query.
int a[maxn],b[maxn],ans[maxn];
// Heavy-light decomposition data produced by dfs1/dfs2.
int dep[maxn],fa[maxn],siz[maxn],son[maxn],top[maxn],dfn[maxn],tim;
// ptt[] holds the nodes in dfn order and pt[x] points at x's slot.
// pt[0] doubles as the "next free slot" cursor; the root (node 1) starts at ptt[0].
int ptt[maxn];
int *pt[maxn]={ptt,ptt};
// Edge-indexed adjacency lists: used for the tree first, then reused for each
// query's virtual tree.
int fir[maxn],nxt[maxn<<1],to[maxn<<1],tot;
// Appends the directed edge x -> y to the edge-indexed adjacency lists
// (fir/nxt/to), using the next free edge id tot.
inline void line(int x,int y){
    tot++;
    to[tot]=y;
    nxt[tot]=fir[x];
    fir[x]=tot;
}
// Sort predicate: node i precedes node j when i comes earlier in the dfn order.
bool cmp(int i,int j){
    return dfn[j]>dfn[i];
}
void dfs1(int u,int ff){
dep[u]=dep[fa[u]=ff]+1,siz[u]=1;
for(int i=fir[u],v;i;i=nxt[i]) if((v=to[i])!=ff){
dfs1(v,u),siz[u]+=siz[v];
if(siz[v]>siz[son[u]]) son[u]=v;
}
}
void dfs2(int u,int tp){
top[u]=tp,dfn[u]=++tim,*pt[u]=u,pt[0]++;
if(son[u]) pt[son[u]]=pt[u]+1,dfs2(son[u],tp);
for(int i=fir[u],v;i;i=nxt[i]) if(!dfn[v=to[i]]) pt[v]=pt[0],dfs2(v,v);
}
// Lowest common ancestor of u and v via the heavy-light decomposition.
// While u and v are on different chains, lift the one whose chain head is
// deeper to the parent of that head. Once both are on one chain, the
// shallower node is the answer.
inline int LCA(int u,int v){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        u=fa[top[u]];
    }
    if(dep[u]<dep[v]) return u;
    return v;
}
// Returns the k-th ancestor of u (k = 0 gives u itself).
// Whole heavy chains are skipped while the parent of the current chain head is
// still within k steps. The remaining steps stay inside u's chain, whose
// ancestors occupy the consecutive slots just before pt[u].
// Precondition: 0 <= k < dep[u]. For larger k the loop does not terminate.
inline int getk(int u,int k){
    while(true){
        int up=fa[top[u]];
        int span=dep[u]-dep[up];
        if(span>k) break;
        k-=span;
        u=up;
    }
    return *(pt[u]-k);
}
// Sentinel distance meaning "no key point found yet" (used by VTree).
const int inf = 1000000000;
// Virtual-tree processing for one query.
// Only the key points and the LCAs inserted by solve() are materialized.
// Every other node of the original tree is counted in bulk through siz[].
// Together, dfs2() and solve() add each original node's 1 to ans[] of its
// nearest key point. Ties go to the smaller node id (see node::operator<).
namespace VTree{
// S: 1-based stack holding the current root-to-node chain while the virtual
// tree is built. top: stack height (inside this namespace it shadows the
// global top[] array). sz: number of valid entries in solve()'s array a[]
// (the k key points plus every inserted LCA), used for cleanup.
int S[maxn],top,sz;
// kp[u] is true iff u is a key point of the current query.
bool kp[maxn];
// Candidate nearest key point: x is its id, d its distance.
struct node{
int x,d;
// Same key point, t edges further away.
node operator + (int t){return (node){x,d+t};}
// Smaller distance wins; on equal distance the smaller id wins.
bool operator < (const node &p)const{return d==p.d?x<p.x:d<p.d;}
}B[maxn];
// B[u]: best candidate currently known for virtual node u.

// Bottom-up pass over the virtual tree rooted at u. Afterwards B[u] is the best
// key point inside u's subtree. A non-key node starts from {0,inf}.
// Virtual edges are stored parent -> child in the global fir/nxt/to lists.
// An edge's length is the depth difference in the original tree.
// NOTE(review): recursive; the depth equals the virtual tree's height, which
// can approach the number of virtual nodes for path-like inputs.
void dfs1(int u){
B[u]=kp[u]?(node){u,0}:(node){0,inf};
for(int i=fir[u],v;i;i=nxt[i]) dfs1(v=to[i]),B[u]=min(B[u],B[v]+(dep[v]-dep[u]));
}
// Top-down pass. It first pushes the parent's candidate into each child, so
// B[v] becomes the best key point over the whole tree. It then distributes
// the original-tree nodes of subtree(u) among key points:
//  - for a virtual edge u -> v, the nodes strictly between u and v, together
//    with everything hanging off them, are split between B[v] (lower part)
//    and B[u] (upper part);
//  - the nodes of subtree(u) that do not lead to any virtual child go to B[u].
// NOTE(review): recursive, same depth concern as dfs1.
void dfs2(int u){
int res=siz[u];// real-tree size of u; the subtrees leading to virtual children are subtracted below
for(int i=fir[u],v,w;i;i=nxt[i]){
// w: length of the virtual edge. B[v] is finalized before recursing into v.
v=to[i],B[v]=min(B[v],B[u]+(w=dep[v]-dep[u])),dfs2(v);
// t: the ancestor of v that is a direct child of u in the original tree.
int t=getk(v,w-1); res-=siz[t];
// Same nearest key point at both ends: the whole segment belongs to it.
if(B[u].x==B[v].x) {ans[B[u].x]+=siz[t]-siz[v];continue;}
// Otherwise B[v].x lies in v's subtree, and the path B[v].x..v..u..B[u].x has
// len nodes. o = getk(v,d) is the highest node on the path that is strictly
// closer to B[v].x (d steps above v). d == -1 only when v itself is
// equidistant; then o = v and nothing above v goes to B[v].x.
// The boundary 'o' therefore always lies on v..t.
int len=B[u].d+B[v].d+w+1,d=len/2-(B[v].d+1),o=d==-1?v:getk(v,d);
// len odd <=> the two key points are an even distance apart, so there is an
// exact midpoint. It is tied and goes to B[v].x only if that id is smaller.
if((len&1)&&d>=0&&B[v].x<B[u].x) o=fa[o];
ans[B[v].x]+=siz[o]-siz[v],ans[B[u].x]+=siz[t]-siz[o];
}
// The rest of subtree(u), including u, reaches every key point through u.
ans[B[u].x]+=res;
}
// Answers one query. a[1..k] are the key points (k is the global query size).
// a[] is sorted by dfn in place, and the inserted LCAs are appended after a[k],
// so the array needs room for up to 2k-1 entries. ans[] is accumulated, not
// reset: the caller must clear the key points' entries afterwards.
void solve(int *a){
sort(a+1,a+1+k,cmp),sz=k;
for(int i=1;i<=k;i++){
kp[a[i]]=1; if(!top) {S[++top]=a[i];continue;}
int lca=LCA(a[i],S[top]);
// Unwind the stack chain down to lca's level, emitting the virtual edge
// S[top-1] -> S[top] for each popped node. lca and S[top-1] are both
// ancestors of S[top], so comparing dfn is the same as comparing depth here.
for(;top>1&&dfn[lca]<=dfn[S[top-1]];top--) line(S[top-1],S[top]);
// If lca is not on the stack yet, insert it as a new virtual node in place of
// the top, and record it in a[] so it gets cleaned up.
if(S[top]!=lca) line(lca,S[top]),a[++sz]=S[top]=lca;
S[++top]=a[i];
}
// Link the remaining chain; S[1] ends up as the virtual-tree root.
for(;top>1;top--) line(S[top-1],S[top]);
dfs1(S[1]),dfs2(S[1]);
ans[B[S[1]].x]+=n-siz[S[1]];// nodes outside subtree(S[1]); note S[1] may not be a key point itself
// Reset the per-query state of every virtual node (key points and inserted LCAs).
for(tot=top=0;sz;sz--) fir[a[sz]]=kp[a[sz]]=0;
}
}
// Reads the tree and builds the heavy-light decomposition. Then, for each
// query, it prints the value accumulated in ans[] for every key point in
// input order and clears those entries for the next query.
int main()
{
    scanf("%d",&n);
    int x,y;
    for(int i=1;i<n;i++){
        scanf("%d%d",&x,&y);
        line(x,y);
        line(y,x);
    }
    dfs1(1,0);
    dfs2(1,1);
    // The tree's adjacency lists are no longer needed; fir/tot are reused for virtual trees.
    memset(fir,0,sizeof fir);
    tot=0;
    scanf("%d",&m);
    while(m--){
        scanf("%d",&k);
        for(int i=1;i<=k;i++){
            scanf("%d",&a[i]);
            b[i]=a[i];          // VTree::solve reorders and extends its argument
        }
        VTree::solve(b);
        for(int i=1;i<=k;i++){
            printf("%d%c",ans[a[i]],i==k?'\n':' ');
            ans[a[i]]=0;
        }
    }
    return 0;
}