题目传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=3572
题目分析:最近跟着tututu入坑虚树,就做了这道题。
虚树教程:https://www.cnblogs.com/zzqsblog/p/5560645.html
虚树大概的思想就是:如果每次询问给出一个点集,在原先的树上跑一遍求答案代价太大,我们可以尝试着构造一个棵小一点的树。这棵树仅包含询问的点集以及它们之间的LCA,它的大小不会超过两倍的点集大小。这样小树上的每一条链就记录着很多点的信息。而只要在这棵小树上求答案,总时间就可以和点集总大小同阶。这颗新构造出来的树就叫做虚树。而针对一个点集构造对应的虚树,可以通过一个栈+LCA实现。
这题就是道虚树的入门题。我们将每次询问的虚树构出来,然后在上面做两次DFS,求出距离虚树上每个点最近的关键点在哪里。然后枚举虚树上的一条边(u,v),设u是v的父亲。倍增出原树中深度最小的mid使得mid被离v最近的那个关键点管辖,然后通过对原树预处理Size即可求得答案。
这题的主要特点是思路简单,但代码很烦,做了我一个晚上。其间我还重构了几个部分的代码,调试的时候debug出无数错误。不过样例很强,只要过了样例就能AC了。注意虚树在构造的过程中,栈内点的父亲随时可能变。虚树上DFS的时候边权为dep[v]-dep[u],其中dep是点在原树中的深度。每次询问前要确保初始化了所有相关点的数组变量。仔细考虑如何计算答案。
CODE:
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
using namespace std;
const int maxn=300100;
const int maxl=21;
struct edge
{
int obj;
edge *Next;
} e[maxn<<2];
edge *head[maxn];
int cur=-1;
int fa[maxn][maxl];
int Size[maxn];
int dep[maxn];
int dfn[maxn];
int Time=0;
int Node[maxn];
int ori[maxn];
int Fa[maxn];
int ans[maxn];
int cnt;
int dis1[maxn];
int dis2[maxn];
int id1[maxn];
int id2[maxn];
bool vis[maxn];
int sak[maxn];
int tail;
int n,q,m;
void Add(int x,int y)
{
cur++;
e[cur].obj=y;
e[cur].Next=head[x];
head[x]=e+cur;
}
void Dfs1(int node)
{
dfn[node]=++Time;
Size[node]=1;
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
if (son==fa[node][0]) continue;
fa[son][0]=node;
dep[son]=dep[node]+1;
Dfs1(son);
Size[node]+=Size[son];
}
}
bool Comp(int x,int y)
{
return dfn[x]<dfn[y];
}
int Lca(int x,int y)
{
if (dep[x]<dep[y]) swap(x,y);
for (int j=maxl-1; j>=0; j--)
if (dep[ fa[x][j] ]>=dep[y]) x=fa[x][j];
if (x==y) return x;
for (int j=maxl-1; j>=0; j--)
if (fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j];
return fa[x][0];
}
void Update(int node,int id,int dis)
{
if ( dis<dis1[node] || ( dis==dis1[node] && id<id1[node] ) )
{
dis2[node]=dis1[node];
id2[node]=id1[node];
dis1[node]=dis;
id1[node]=id;
}
else
if ( dis<dis2[node] || ( dis==dis2[node] && id<id2[node] ) )
{
dis2[node]=dis;
id2[node]=id;
}
}
void Dfs2(int node)
{
if (vis[node]) dis1[node]=0,id1[node]=node;
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
Dfs2(son);
Update(node,id1[son],dis1[son]+dep[son]-dep[node]);
}
}
void Dfs3(int node,int id,int dis)
{
Update(node,id,dis);
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
if (id1[node]==id1[son]) Dfs3(son,id2[node],dis2[node]+dep[son]-dep[node]);
else Dfs3(son,id1[node],dis1[node]+dep[son]-dep[node]);
}
}
int Jump(int x,int y)
{
int z=x;
for (int j=maxl-1; j>=0; j--)
{
int w=fa[z][j];
int d1=dep[w]-dep[y]+dis1[y];
int d2=dep[x]-dep[w]+dis1[x];
if ( dep[w]>dep[y] && ( d1>d2 || ( d1==d2 && id1[y]>id1[x] ) ) ) z=w;
}
return z;
}
int main()
{
freopen("3572.in","r",stdin);
freopen("3572.out","w",stdout);
scanf("%d",&n);
for (int i=1; i<=n; i++) head[i]=NULL;
for (int i=1; i<n; i++)
{
int u,v;
scanf("%d%d",&u,&v);
Add(u,v);
Add(v,u);
}
dep[1]=1;
Dfs1(1);
for (int j=1; j<maxl; j++)
for (int i=1; i<=n; i++)
fa[i][j]=fa[ fa[i][j-1] ][j-1];
scanf("%d",&q);
while (q--)
{
scanf("%d",&m);
for (int i=1; i<=m; i++) scanf("%d",&Node[i]),ori[i]=Node[i];
sort(Node+1,Node+m+1,Comp);
tail=0;
cnt=m;
sak[ ++tail ]=Node[1];
for (int i=2; i<=m; i++)
{
int p=Node[i];
int x=Lca(p,sak[tail]);
int last=0;
while ( dep[ sak[tail] ]>dep[x] && tail ) last=sak[tail],tail--;
if ( dep[ sak[tail] ]<dep[x] ) Fa[x]=sak[tail],sak[++tail]=x,Node[++cnt]=x;
if (last) Fa[last]=x;
Fa[p]=x;
sak[++tail]=p;
}
Fa[ sak[1] ]=0;
for (int i=1; i<=cnt; i++) head[ Node[i] ]=NULL;
int Min=Node[1];
for (int i=1; i<=cnt; i++)
{
int x=Node[i];
if (Fa[x]) Add(Fa[x],x);
if (dfn[x]<dfn[Min]) Min=x;
dis1[x]=dis2[x]=maxn;
if (i<=m) vis[x]=true; else vis[x]=false;
ans[x]=0;
}
Dfs2(Min);
Dfs3(Min,0,maxn);
for (int i=1; i<=cnt; i++)
{
int x=Node[i];
int y=Fa[x];
int p=id1[x];
int q=id1[y];
if (y)
{
int z=Jump(x,y);
ans[p]+=(Size[z]-Size[x]);
ans[q]-=Size[z];
}
else ans[p]+=(n-Size[x]);
ans[p]+=Size[x];
}
for (int i=1; i<=m; i++) printf("%d ",ans[ ori[i] ]);
printf("\n");
}
return 0;
}