题意:
给出一棵树,每次询问会给出几个关键点,要求选最少的非关键点使得把选的点去掉后关键点之间两两不能到达。
n≤100000
n
≤
100000
题解:
虚树模版。
dp的话就是设
f[x][0]
f
[
x
]
[
0
]
表示子树内关键点两两不连通,且没有点可以连到子树外,
f[x][1]
f
[
x
]
[
1
]
表示允许有点连到子树外。转移注意细节。
code:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#define LL long long
using namespace std;
struct node{
LL y,next;
}a[200010];LL len=0,last[100010];
LL n,dfn[100010],z=0;
struct trnode{
LL dep,fa[20];
}tr[100010];
void ins(LL x,LL y)
{
a[++len].y=y;
a[len].next=last[x];last[x]=len;
}
void dfs(LL x,LL fa)
{
tr[x].dep=tr[fa].dep+1;tr[x].fa[0]=fa;dfn[x]=++z;
for(LL i=1;(1<<i)<=tr[x].dep;i++)
tr[x].fa[i]=tr[tr[x].fa[i-1]].fa[i-1];
for(LL i=last[x];i;i=a[i].next)
if(a[i].y!=fa) dfs(a[i].y,x);
}
LL vis[100010],tim=0;
LL h[100010],tot,sta[100010];
bool cmp(LL x,LL y) {return dfn[x]<dfn[y];}
LL find_lca(LL x,LL y)
{
if(tr[x].dep<tr[y].dep) swap(x,y);
for(LL i=17;i>=0;i--)
if((1<<i)<=tr[x].dep-tr[y].dep) x=tr[x].fa[i];
if(x==y) return x;
for(LL i=17;i>=0;i--)
if((1<<i)<=tr[x].dep&&tr[x].fa[i]!=tr[y].fa[i]) x=tr[x].fa[i],y=tr[y].fa[i];
return tr[x].fa[0];
}
void build()
{
sort(h+1,h+tot+1,cmp);
LL top=0;sta[++top]=1;
for(LL i=1;i<=tot;i++)
{
LL lca=find_lca(h[i],sta[top]);
if(lca==sta[top])
{
if(sta[top]!=h[i]) sta[++top]=h[i];
continue;
}
while(top&&tr[sta[top-1]].dep>=tr[lca].dep) ins(sta[top-1],sta[top]),top--;
if(lca!=sta[top]) ins(lca,sta[top]),sta[top]=lca;
sta[++top]=h[i];
}
while(top>1) ins(sta[top-1],sta[top]),top--;
len=0;
}
const LL inf=1<<28;
LL f[100010][2];
void dp(LL x)
{
f[x][0]=f[x][1]=0;
for(LL i=last[x];i;i=a[i].next) dp(a[i].y);
if(vis[x]==tim)
{
f[x][0]=inf;
for(LL i=last[x];i;i=a[i].next)
f[x][1]+=min(f[a[i].y][0],f[a[i].y][1]+1);
}
else
{
LL Min=0;
for(LL i=last[x];i;i=a[i].next)
{
LL y=a[i].y;
f[x][0]+=min(f[y][0],f[y][1]);
f[x][1]+=f[y][0];
Min=min(Min,f[y][1]-f[y][0]);
}
f[x][0]=min(f[x][0]+1,f[x][1]);f[x][1]+=Min;
}
last[x]=0;
}
void solve()
{
tot;scanf("%lld",&tot);tim++;
for(LL i=1;i<=tot;i++) scanf("%lld",&h[i]),vis[h[i]]=tim;
for(LL i=1;i<=tot;i++) if(vis[tr[h[i]].fa[0]]==tim) {puts("-1");return;}
build();dp(1);
printf("%lld\n",min(f[1][1],f[1][0]));
}
int main()
{
scanf("%lld",&n);
for(LL i=1;i<n;i++)
{
LL x,y;scanf("%lld %lld",&x,&y);
ins(x,y);ins(y,x);
}
tr[0].dep=-1;dfs(1,0);
LL q;scanf("%lld",&q);
len=0;memset(last,0,sizeof(last));
while(q--) solve();
}