第一次做虚树的题…
题意:给出一棵树,多次询问,每次给出k[i]个点,询问将这些点从树上分离开来最少需要删除多少个点,保证k[i]的和不超过100000.
我们先建虚树,然后在虚树上Dp就可以了,我们设Dp[i][0/1],若为0表示这个点子树中所有关键点与这个点都断开了,1表示还有1个关键点连在这个点上。(不可能有大于等于2个点,不然就不合法了)。
那么若当前点是关键点,
Dp[i][0]=inf,
Dp[i][1]=sigma(min(Dp[son][0] ,Dp[son][1]+1)),但是如果当前这条边的边权为1,就只能用Dp[son][0]转移得到。
如果是非关键点:
S1=sigma(min(Dp[son][0],Dp[son][1]))+1,不管怎么样,我们都直接把这个点给删去,不考虑全部是Dp[son][0]的情况。
S2=sigma(min(Dp[son][0],Dp[son][1]+1)),表示不删这个点,而删与儿子相连的虚边上的点,同上,如果当前这条边的边权为1,就只能用Dp[son][0]转移得到。
首先如果Dp[i][0]全部由Dp[son][0]转来,那么就是S2,否则S1一定比S2更优。
所以Dp[i][0]=min(S1,S2).
然后考虑Dp[i][1],那么就是在Dp[i][0]的条件上减去某个min(Dp[son][0],Dp[son][1]+1),再加上某个Dp[son][1]就可以了(这样可以保证只有一个关键点和当前点i相连),于是我们在记录S1,S2的时候顺便记录一个min : Dp[son][1] -min(Dp[son][0],Dp[son][1]+1),还是同上, 如果当前这条边的边权为1,就只能用Dp[son][0]转移得到。
然后min(Dp[root][0],Dp[root][1])就是答案啦。
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <string>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <ctime>
#define inf (long long)100000000
using namespace std;
struct node {int to;int next;int len;
};node edge[500010],bian[2000010];
int in[200010],out[200010],dep[200010],f[200010][21],root;
int tim = 0,fir[200010],first[200010],n,a,b,m,Q,stack[200010];
int sum=0,size=0,len,top,list[200010],que[200010];
long long Dp[200010][3];
bool instack[200010],mark[200010];
bool comp(const int &x,const int &y) {return in[x] < in[y];}
bool check(int x,int y) {return in[x] <= in[y] && out[x] >= out[y];}
void add(int x,int y) {
edge[ ++ sum].to = y;
edge[sum].next = fir[x];
fir[x] = sum;
}
void inser(int x,int y,int z) {
bian[ ++ size].to = y;
bian[size].next = first[x];
first[x] = size;
bian[size].len = z;
}
void dfs(int x,int depth,int Anc) {
dep[x] = depth;
f[x][0] = Anc;
in[x] = ++ tim;
for(int i = 1;i <= 20;i ++)
f[x][i] = f[f[x][i - 1]][i - 1];
for(int u = fir[x];u;u = edge[u].next)
if(dep[edge[u].to] == 0)
dfs(edge[u].to,depth + 1,x);
out[x] = ++ tim;
}
int lca(int x,int y) {
if(dep[x] < dep[y]) swap(x,y);
for(int i = 20;i >= 0;i --)
if(dep[f[x][i]] >= dep[y])
x = f[x][i];
if(x == y) return x;
for(int i = 20;i >= 0;i --)
if(f[x][i] != f[y][i])
x = f[x][i],y = f[y][i];
return f[x][0];
}
void dp(int x) {
for(int u = first[x];u;u = bian[u].next)
if(instack[bian[u].to])
dp(bian[u].to);
if(mark[x] == false) {
long long s1 = 0,s2 = 0;
long long minx = 2 * inf;
for(int u = first[x];u;u = bian[u].next)
if(instack[bian[u].to])
{
long long m1 = min(Dp[bian[u].to][0],Dp[bian[u].to][1]);
long long m2;
if(bian[u].len > 1)
m2 = min(Dp[bian[u].to][0],Dp[bian[u].to][1] + 1);
else m2 = min(Dp[bian[u].to][0],inf);
s1 += m1;
s2 += m2;
minx = min(minx,Dp[bian[u].to][1] - m2);
}
Dp[x][0] = min(s1 + 1,s2);
Dp[x][1] = min(inf,s2 + minx);
}
if(mark[x] == true) {
long long s = 0;
Dp[x][0] = inf;
for(int u = first[x];u;u = bian[u].next)
if(instack[bian[u].to])
{
long long c;
if(bian[u].len > 1) c = Dp[bian[u].to][1] + 1;
else c = inf;
s += min(c,Dp[bian[u].to][0]);
}
if(s >= inf) s = inf;
Dp[x][1] = s;
return ;
}
}
int main() {
scanf("%d",&n);
for(int i = 1;i <= n - 1;i ++) {
scanf("%d%d",&a,&b);
add(a,b);
add(b,a);
}
dfs(1,1,1);
scanf("%d",&Q);
while(Q --) {
scanf("%d",&m);len = m;top = 0;root = 0;
for(int i = 1;i <= m;i ++) scanf("%d",&list[i]),que[i] = list[i];
sort(list + 1,list + m + 1,comp);
for(int i = 1;i < m;i ++)
list[++ len] = lca(list[i],list[i + 1]);
sort(list + 1,list + len + 1,comp);
len = unique(list + 1,list + len + 1) - list - 1;
for(int i = 1;i <= len;i ++) {
while(top > 0 && !check(stack[top],list[i]))
top --;
if(stack[top] == 0) root = list[i];
else {
int s = stack[top],t = list[i];
inser(s,t,dep[t] - dep[s]);
}
stack[ ++ top] = list[i];
}
for(int i = 1;i <= len;i ++) instack[list[i]] = true;
for(int i = 1;i <= m;i ++) mark[que[i]] = true;
dp(root);
for(int i = 1;i <= len;i ++) instack[list[i]] = false;
for(int i = 1;i <= m;i ++) mark[que[i]] = false;
for(int i = 1;i <= len;i ++) first[list[i]] = 0;size = 0;
if(Dp[root][1] >= inf && Dp[root][0] >=inf) printf("-1\n");
else printf("%I64d\n",min(Dp[root][1],Dp[root][0]));
}
return 0;
}