题意
给你很多个询问,每个询问给你很多个关键点,问至少删除多少个非关键点,可以使所有关键点两两不连通
分析
首先建出虚树,对于虚树treedp一下
f[i][0/1]
f
[
i
]
[
0
/
1
]
表示i子树内,关键点跟不跟外界联通
然后分一下当前点是不是关键点来转移就好了
代码
#include <bits/stdc++.h>
#define cl clear
#define pb push_back
using namespace std;
const int inf = 1e9;
const int N = 200010;
inline int read()
{
int p=0; int f=1; char ch=getchar();
while(ch<'0' || ch>'9'){if(ch=='-') f=-1; ch=getchar();}
while(ch>='0' && ch<='9'){p=p*10+ch-'0'; ch=getchar();}
return p*f;
}
struct node{int x,y,next;}edge[N]; int len,first[N];
void ins(int x,int y){len++; edge[len].x=x; edge[len].y=y; edge[len].next=first[x]; first[x]=len;}
int fa[N][21]; int dep[N]; int dfn[N],id;
void dfs(int x,int f)
{
dfn[x] = ++id;
for(int k=first[x];k!=-1;k=edge[k].next)
{
int y=edge[k].y;
if(y==f) continue;
fa[y][0] = x; dep[y] = dep[x] +1; dfs(y,x);
}
}
vector<int> v;
bool cmp(const int &x,const int &y){return dfn[x] < dfn[y];}
vector<int> g[N]; bool col[N];
void inss(int x,int y){g[x].pb(y);}
int f[N][2];
int dfs2(int x)
{
if(col[x])
{
f[x][1] = 0; f[x][0] = 1;
for(int i=0;i<g[x].size();i++)
{
int y=g[x][i];
dfs2(y);
f[x][1] += f[y][0]; f[x][0] += f[y][0];
}
}
else
{
f[x][1] = 0; f[x][0] = 1; int mx = inf; int s = 0;
for(int i=0;i<g[x].size();i++)
{
int y=g[x][i];
dfs2(y);
f[x][0] += min(f[y][1] , f[y][0]); s += f[y][0];
f[x][1] += f[y][0]; mx = min(f[y][1] - f[y][0],mx);
}f[x][0] = min(f[x][0] , s);
f[x][1] += mx;
}
}
int s[N],top = 0;
int LCA(int x,int y)
{
if(dep[x] < dep[y]) swap(x,y);
int deep = dep[x] - dep[y];
for(int i=18;i>=0;i--) if( deep >= (1<<i) ){deep -= (1<<i); x=fa[x][i];}
if(x==y) return x;
for(int i=18;i>=0;i--) if( fa[x][i] != fa[y][i] ){x=fa[x][i]; y=fa[y][i];}
return fa[x][0];
}
vector<int> vc;
int main()
{
len = 0; memset(first,-1,sizeof(first));
int n = read();
for(int i=1;i<n;i++){int x=read(); int y=read(); ins(x,y); ins(y,x);}
memset(dep,0,sizeof(dep)); dep[1] = 1; id = 0; dfs(1,0);
for(int j=1;j<=18;j++) for(int i=1;i<=n;i++) fa[i][j] = fa[fa[i][j-1]][j-1];
int q = read(); memset(col,0,sizeof(col));
while(q--)
{
int k = read();
v.cl(); for(int i=1;i<=k;i++) v.pb(read());
sort(v.begin(),v.end(),cmp); bool bk = 1;
for(int i=0;i<v.size();i++)
{
col[v[i]] = 1;
if(i && col[fa[v[i]][0]]){bk = 0; break;}
}
if(!bk)
{
printf("-1\n");
goto s1;
}
top = 1; s[top] = v[0]; vc.clear(); vc.pb(v[0]);
for(int i=1;i<v.size();i++)
{
if(v[i] == v[i-1]) continue;
int p = LCA(s[top],v[i]);
while(top && dfn[p] < dfn[s[top-1]]){inss(s[top-1],s[top]); top--;}
if(s[top] != p) inss(p,s[top]),top--;
if(s[top] != p) s[++top] = p,vc.pb(p);
s[++top] = v[i],vc.pb(v[i]);
}while(top>1){inss(s[top-1],s[top]); top--;}
// printf("%d\n",vc.size());
dfs2(s[1]);
printf("%d\n",min(f[s[1]][1],f[s[1]][0]));
s1:for(int i=0;i<v.size();i++) col[v[i]] = 0;
for(auto i:vc) g[i].cl(),f[i][0] = f[i][1] = 0;
}
return 0;
}