原题
题目大意
给你一颗
n≤100000
个点的树,然后有
q≤100000
个询问,每个询问选定
k
个点,问,至少在树中删除多少个点,使得
保证
∑k≤100000
解题思路
首先可以确定的是,无解一定是两个点相邻。
其他的情况就要树形dp。
如果当前点为非选定点,且子树中有大于2个选定点,则这个点要删除。
如果这个点是选定点,且对于其中一个儿子的子树有选定点,那么就把这个儿子删除。
但是每次询问都把
n
个点做树形dp太耗时间,每次只要把
参考代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define maxn 100005
#define maxsq 20
#define mem(a,b) memset(a,b,sizeof(a))
using namespace std;
int head[maxn],t[maxn * 2],next[maxn * 2],sum;
int fr[maxn],to[maxn],tim;
int n,m,q;
int fa[maxn][maxsq],deep[maxn];
int a[maxn*2];
bool bz[maxn];
int stack[maxn];
int read(){
int ret=0,ff=1;
char ch=getchar();
while (ch<'0' || ch>'9') {
if (ch=='-') ff=-1;
ch=getchar();
}
while (ch>='0' && ch<='9') {
ret=ret*10+ch-'0';
ch=getchar();
}
return ret*ff;
}
bool cmp(int i,int j){
return fr[i]<fr[j];
}
void insert(int x,int y){
t[++sum]=y;
next[sum]=head[x];
head[x]=sum;
}
void dfs(int x,int father){
deep[x]=deep[father]+1;
fr[x]=++tim;
for(int tmp=head[x];tmp;tmp=next[tmp]) {
if (t[tmp]==father) continue;
fa[t[tmp]][0]=x;
dfs(t[tmp],x);
}
to[x]=tim;
}
int getlca(int x,int y){
if (deep[x]<deep[y]) swap(x,y);
fd(i,18,0)
if (deep[fa[x][i]]>=deep[y]) x=fa[x][i];
if (x==y) return x;
fd(i,18,0)
if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int last[maxn];
int treedp(int w){
int ret=0,tot=0;
for(int tmp=head[w];tmp;tmp=next[tmp]){
ret+=treedp(t[tmp]);
tot+=last[t[tmp]];
}
if (bz[w]) {
ret=ret+tot;
last[w]=1;
}
else {
if (tot>1) ret++;
if (tot==1) last[w]=1;
else last[w]=0;
}
return ret;
}
int main(){
n=read();
fo(i,1,n-1) {
int x=read(),y=read();
insert(x,y);
insert(y,x);
}
dfs(1,0);
fa[1][0]=1;
fo(i,1,18)
fo(j,1,n) fa[j][i]=fa[fa[j][i-1]][i-1];
q=read();
while (q--) {
sum=0;
a[0]=0;
stack[0]=0;
m=read();
fo(i,1,m) {
int x=read();
a[++a[0]]=x;
bz[x]=1;
}
bool pd=0;
fo(i,1,m) {
int now=a[i];
if (fa[a[i]][0]!=a[i] && bz[fa[a[i]][0]]) {
pd=1;
break;
}
}
if (pd) {
fo(i,1,a[0]) bz[a[i]]=0;
puts("-1");
continue;
}
sort(a+1,a+a[0]+1,cmp);
fo(i,1,m-1) a[++a[0]]=getlca(a[i],a[i+1]);
sort(a+1,a+a[0]+1,cmp);
int nowtot=0,last=0;
fo(i,1,a[0]) {
if (a[i]==last) continue;
a[++nowtot]=a[i];
last=a[i];
}
a[0]=nowtot;
fo(i,1,a[0]) {
head[a[i]]=0;
while (stack[0]>0 && !(fr[stack[stack[0]]]<=fr[a[i]] && fr[a[i]]<=to[stack[stack[0]]]))
stack[0]--;
if (stack[0]>0) insert(stack[stack[0]],a[i]);
stack[++stack[0]]=a[i];
}
printf("%d\n",treedp(a[1]));
fo(i,1,a[0]) bz[a[i]]=0;
}
return 0;
}