很明显的虚树题。
把关键点的虚树构建出来,然后可以两遍遍历得到离点i最近的关键点的距离和编号。那么现在考虑虚树中的一对点(x,y),x为y的某个儿子,考虑其对答案的影响。
由于是虚树,那么显然所有y->x的路径上的点,这个点延伸出去的点中(不包含由y->x的路径)不会有关键点存在,那么离这些点最近的虚树中的点,要么是x,要么是y,而且一定是先到达y->x的路径上的某一点,然后到达x或y,最后到达关键点。那么y->x的路径上有有一个分界点z,使得y->z的路径上延伸出去的点都由y到达关键点,z->x都由x到达关键点。而x到关键点的路径一定是往下走,y则一定是先往上走,因此求出关于x个y的两个关键点的中点,就是z了。
然后用子树大小统计路径及其延伸出去的点的个数即可,最后要加上那些没有被统计到的点。
AC代码如下:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 300005
#define inf 1000000000
using namespace std;
int n,m,tot,dfsclk,bin[25],fst[N],pnt[N<<1],nxt[N<<1],fa[N][19],d[N],pos[N];
int a[N],p[N],id[N],val[N],anc[N],len[N],ans[N],sz[N],q[N]; struct node{ int x,y; }g[N];
int read(){
int x=0; char ch=getchar();
while (ch<'0' || ch>'9') ch=getchar();
while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); }
return x;
}
void add(int x,int y){
pnt[++tot]=y; nxt[tot]=fst[x]; fst[x]=tot;
}
void dfs(int x){
pos[x]=++dfsclk; sz[x]=1; int p,i;
for (i=1; bin[i]<=x; i++) fa[x][i]=fa[fa[x][i-1]][i-1];
for (p=fst[x]; p; p=nxt[p]){
int y=pnt[p];
if (y!=fa[x][0]){
fa[y][0]=x; d[y]=d[x]+1;
dfs(y); sz[x]+=sz[y];
}
}
}
int lca(int x,int y){
if (d[x]<d[y]) swap(x,y); int tmp=d[x]-d[y],i;
for (i=0; bin[i]<=tmp; i++)
if (tmp&bin[i]) x=fa[x][i];
for (i=18; i>=0; i--)
if (fa[x][i]!=fa[y][i]){ x=fa[x][i]; y=fa[y][i]; }
return (x==y)?x:fa[x][0];
}
int find(int x,int dep){
int i; for (i=18; i>=0; i--) if (d[fa[x][i]]>=dep) x=fa[x][i];
return x;
}
bool cmp(int x,int y){ return pos[x]<pos[y]; }
bool lss(node u,node v){
return u.x<v.x || u.x==v.x && u.y<v.y;
}
void solve(){
m=read(); int i,cnt=m,tp=0; node t;
for (i=1; i<=m; i++){
a[i]=id[i]=p[i]=read(); g[a[i]].y=a[i];
g[a[i]].x=ans[a[i]]=0;
}
sort(a+1,a+m+1,cmp);
for (i=1; i<=m; i++)
if (!tp){ q[++tp]=a[i]; anc[a[i]]=0; } else{
int tmp=lca(a[i],q[tp]);
for (; d[q[tp]]>d[tmp]; tp--)
if (d[q[tp-1]]<=d[tmp]) anc[q[tp]]=tmp;
if (q[tp]!=tmp){
p[++cnt]=tmp; anc[tmp]=q[tp];
q[++tp]=tmp; g[tmp].x=inf; g[tmp].y=0;
}
anc[a[i]]=tmp; q[++tp]=a[i];
}
sort(p+1,p+cnt+1,cmp);
for (i=1; i<=cnt; i++){
int x=p[i]; val[x]=sz[x];
if (i>1) len[x]=d[x]-d[anc[x]];
}
for (i=cnt; i>1; i--){
int x=p[i]; t=g[x]; t.x+=len[x];
if (lss(t,g[anc[x]])) g[anc[x]]=t;
}
for (i=2; i<=cnt; i++){
int x=p[i]; t=g[anc[x]]; t.x+=len[x];
if (lss(t,g[x])) g[x]=t;
}
for (i=1; i<=cnt; i++){
int x=p[i],y=anc[x];
if (i==1) ans[g[x].y]+=n-sz[x]; else{
int tmp=find(x,d[y]+1),sum=sz[tmp]-sz[x];
val[y]-=sz[tmp];
if (g[x].y==g[y].y) ans[g[x].y]+=sum; else{
int z=d[x]-((g[y].x+len[x]-g[x].x)>>1);
if (!((g[y].x+g[x].x+len[x])&1) && g[x].y>g[y].y) z++;
z=sz[find(x,z)]-sz[x];
ans[g[x].y]+=z; ans[g[y].y]+=sum-z;
}
}
}
for (i=1; i<=cnt; i++) ans[g[p[i]].y]+=val[p[i]];
for (i=1; i<=m; i++) printf("%d ",ans[id[i]]); puts("");
}
int main(){
n=read(); int i;
bin[0]=1; for (i=1; i<=19; i++) bin[i]=bin[i-1]<<1;
for (i=1; i<n; i++){
int x=read(),y=read();
add(x,y); add(y,x);
}
d[1]=1; dfs(1);
int cas=read(); while (cas--) solve();
return 0;
}
by lych
2016.3.7