第一次接触虚树参考博客学习 http://blog.csdn.net/braketbn/article/details/50887470
虚树是根据需要查询的点以及他们的lca重新构建的一颗树,链之间的信息整合压缩后保存在新的边上。
然后这道题就是根据关键点和它们的lca构建虚树,然后做树上DP。
针对当前访问的点是否为关键点进行分类处理。
如果是关键点,如果其儿子节点的子树中存在单个关键点,那么删除该儿子节点。
如果不是关键点,若其子树中有还没有分离开的数量超过2个的关键节点,那么需要删除这个节点。
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e5+10;
struct Edge{
int to, nx;
}e[maxn*2];
bool imp[maxn];
int head[maxn], tot, in[maxn], out[maxn], clo;
int pre[maxn][25];
int dep[maxn];
int n, m;
bool cmp(int x, int y){
return in[x] < in[y];
}
inline void add(int from, int to){
e[++tot].to = to;
e[tot].nx = head[from];
head[from] = tot;
}
void dfs(int node, int p){
in[node] = ++clo;
pre[node][0] = p;
for (int i = 1; i <= 20; i++)
pre[node][i] = pre[pre[node][i-1]][i-1];
for (int k = head[node]; k != -1; k = e[k].nx){
if (e[k].to == p) continue;
dep[e[k].to] = dep[node]+1;
dfs(e[k].to, node);
}
out[node] = ++clo;
}
int node[maxn];
inline int dp(int x){
int cnt = 0, ans = 0;
for (int k = head[x]; k != -1; k = e[k].nx){
ans += dp(e[k].to);
cnt += node[e[k].to];
}
if (imp[x]){
ans += cnt;
node[x] = 1;
}
else{
ans += cnt > 1;
node[x] = cnt == 1;
}
return ans;
}
int lca(int x, int y){
if (dep[x] < dep[y]) swap(x, y);
for (int i = 20; i >= 0; i--){
if (dep[pre[x][i]] >= dep[y])
x = pre[x][i];
}
if (x == y) return x;
for (int i = 20; i >= 0; i--){
if (pre[x][i] != pre[y][i])
x = pre[x][i], y = pre[y][i];
}
return pre[x][0];
}
int main(){
std::ios::sync_with_stdio(false);
cin >> n;
memset(head, -1, sizeof(head));
for (int i = 1; i < n; i++){
int from, to;
cin >> from >> to;
add(from, to); add(to, from);
}
dep[1] = 1;
dfs(1, 1);
int T;
cin >> T;
memset(imp, false, sizeof(imp));
while(T--){
vector<int> vec;
int m;
cin >> m;
for (int i = 0; i < m; i++){
int x;
cin >> x;
vec.push_back(x);
imp[x] = true;
}
bool flag = false;
for (int i = 0; i < vec.size(); i++){
if (pre[vec[i]][0] != vec[i] && imp[pre[vec[i]][0]]){
flag = true;
break;
}
}
if (flag){
cout << -1 << endl;
memset(imp, false, sizeof(imp));
continue;
}
sort(vec.begin(), vec.end(), cmp);
for (int i = 1; i < m; i++)
vec.push_back(lca(vec[i], vec[i-1]));
sort(vec.begin(), vec.end(), cmp);
vec.resize(unique(vec.begin(), vec.end()) - vec.begin());
stack<int> stk; tot = 0;
for (int i = 0; i < vec.size(); i++){
int x = vec[i]; head[x] = -1;
while(!stk.empty() && !(in[stk.top()] <= in[x] && out[stk.top()] >= out[x])) stk.pop();
if (!stk.empty()){
add(stk.top(), x);
}
stk.push(x);
}
cout << dp(vec[0]) << endl;
memset(imp, false, sizeof(imp));
}
return 0;
}