上一期的cf上出了一道虚树的题目,“虚树”一直听别人讲,但自己始终没有去学习,于是去这里学习了下,还是比较简单易懂的。大概就是说,针对一类每次询问树上部分点的信息的问题,我们可以把被询问的点单独拿出来,为了维护这些点的相对位置,我们找到一个点数最少的关于点的子树(相当于缩掉那些无关紧要的点),然后再在这棵树上做操作。可以证明,这棵树的大小将是
O(询问点的个数)
,这颗树就叫做虚树。
构造虚树就是关键咯。首先将点按照dfs序排序,然后我们把点依次添加进虚树当中去。添加的核心就在于维护从根到当前点的dfs链,使用一个栈维护。可以在脑中模拟一下dfs的过程有助于理解几种情况的应对策略。另外一个关键的地方就是每次在点出栈的时候加边,因为当点出栈的时候说明这个点的子树已经完全访问完毕。附上这题的代码,第一次写,虽然通过了题目,但还是可能有漏洞,另外可以做一下哪个网站的虚树专题,还有这题,虽然这题和虚树没什么关系。
#include<bits/stdc++.h>
using namespace std;
const int Maxn=100020,Inf=1e9;
int n;
vector<int>G[Maxn];
vector<int>G2[Maxn];
vector<int>occ;
int pre[Maxn];
int dfs_t;
int flag;
int son[Maxn],f[Maxn],h[Maxn],sz[Maxn],low[Maxn];
int is[Maxn];
void dfs(int u,int p)
{
f[u]=p;
h[u]=h[p]+1;
sz[u]=1;
for(int i=0;i<G[u].size();i++)
{
int v=G[u][i];if(v==p)continue;
dfs(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]])son[u]=v;
}
}
void dfs2(int u,int p,int Low)
{
pre[u]=++dfs_t;
low[u]=Low;
if(son[u])dfs2(son[u],u,Low);
for(int i=0;i<G[u].size();i++)
{
int v=G[u][i];if(v==p||v==son[u])continue;
dfs2(v,u,v);
}
}
int getlca(int u,int v)
{
while(low[u]!=low[v])
{
if(h[low[u]]>h[low[v]])u=f[low[u]];
else v=f[low[v]];
}
return h[u]>h[v]?v:u;
}
bool cmp(int a,int b){return pre[a]<pre[b];}
void add(int u,int v)
{
if(is[u]&&is[v])
{
if(f[v]==u){flag=0;return;}
G2[u].push_back(f[v]);
G2[f[v]].push_back(v);
occ.push_back(u);
occ.push_back(f[v]);
}
else
{
occ.push_back(u);
G2[u].push_back(v);
}
}
int dp[Maxn][2][2],rep[Maxn];
int sta[Maxn];
inline void upd(int &x,int y){if(x>y)x=y;}
void solve(int u)
{
int tp[2][2];
int cs=0;
dp[u][1][0]=is[u]?Inf:1;
tp[0][0]=is[u]?Inf:0;
tp[0][1]=is[u]?0:Inf;
for(int i=0;i<G2[u].size();i++,cs^=1)
{
int v=G2[u][i];solve(v);
tp[cs^1][0]=tp[cs^1][1]=Inf;
dp[u][1][0]+=rep[v];
upd(dp[u][1][0],Inf);
upd(tp[cs^1][0],tp[cs][0]+min(dp[v][1][0],dp[v][0][0]));
upd(tp[cs^1][1],tp[cs][1]+min(dp[v][1][0],dp[v][0][0]));
upd(tp[cs^1][1],tp[cs][0]+dp[v][0][1]);
}
dp[u][0][0]=tp[cs][0];
dp[u][0][1]=tp[cs][1];
rep[u]=dp[u][1][0];
upd(rep[u],dp[u][0][0]);
upd(rep[u],dp[u][0][1]);
//printf("u=%d\n",u);
//printf("%d %d %d\n",dp[u][1][0],dp[u][0][0],dp[u][0][1]);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int u,v;scanf("%d%d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1,0);
dfs2(1,0,1);
int q;scanf("%d",&q);
while(q--)
{
int k;scanf("%d",&k);
vector<int>V;
for(int i=0;i<k;i++)
{
int x;scanf("%d",&x);
is[x]=1;
V.push_back(x);
}
if(!is[1]){V.push_back(1);}
sort(V.begin(),V.end(),cmp);
//for(int i=0;i<V.size();i++)printf("%d ",V[i]);puts("");
int top=0;
sta[++top]=1;
occ.clear();
flag=1;
for(int i=1;i<V.size();i++)
{
int x=V[i];
int lca=getlca(x,sta[top]);
if(lca!=sta[top])
{
while(pre[sta[top-1]]>pre[lca])
{
add(sta[top-1],sta[top]);
top--;
}
add(lca,sta[top]);
top--;
if(lca!=sta[top])sta[++top]=lca;
}
sta[++top]=x;
}
while(top>1){add(sta[top-1],sta[top]);top--;}
if(flag)solve(1);
for(int i=0;i<occ.size();i++)G2[occ[i]].clear();
for(int i=0;i<V.size();i++)is[V[i]]=0;
if(rep[1]!=Inf&&flag)printf("%d\n",rep[1]);
else puts("-1");
}
}