【题目链接】
【前置技能】
- 树形DP
- 虚树
- 树上倍增
【题解】
- 关键点的数量和与 N N 同阶,那么建虚树,剩下来的问题就在于树形DP了。
- 先两次树形DP求出虚树上每个点的最近关键点和它们之间的距离。第一次DP自叶节点向上,考虑以该节点为根的子树中离该节点最近的关键点。第二次DP自根向下,考虑祖先节点的最近关键点对于该节点的影响。
- 然后再遍历虚树上的每一条边,虚树边上的点的最近关键点只可能是这条边两端节点的最近关键点之一,倍增找出关键点控制范围的分界点。
- size s i z e 记录的是以该节点为根的子树在原树中的大小。统计答案时,先认为该节点的最近关键点能够控制以该节点为根的子树中的所有点,将 size[pos] s i z e [ p o s ] 统计入答案。在枚举该节点连向儿子 son s o n 的虚树边的时候,将分界点的子树大小 size[d] s i z e [ d ] 减去,将 d d 到之间的点的数量 size[d]−size[son] s i z e [ d ] − s i z e [ s o n ] 统计入 son s o n 的最近关键点的答案中。最后注意一下虚树的根和 1 1 号节点之间的点要统计入答案。
- 时间复杂度
【代码】
#include<bits/stdc++.h>
#define INF 0x3f3f3f3f
#define LL long long
#define MAXN 300010
#define MAXLOG 22
using namespace std;
int n, Q, k, q[MAXN], p[MAXN];
int ans[MAXN], dis[MAXN], f[MAXN];
int dfn[MAXN], tim, dep[MAXN], used[MAXN], cnt, sta[MAXN], top, mark[MAXN], root, fa[MAXN][MAXLOG + 2], size[MAXN];
vector <int> a[MAXN], b[MAXN];
template <typename T> void chkmin(T &x, T y){x = min(x, y);}
template <typename T> void chkmax(T &x, T y){x = max(x, y);}
template <typename T> void read(T &x){
x = 0; int f = 1; char ch = getchar();
while (!isdigit(ch)) {if (ch == '-') f = -1; ch = getchar();}
while (isdigit(ch)) {x = x * 10 + ch - '0'; ch = getchar();}
x *= f;
}
bool cmp(int a, int b){
return dfn[a] < dfn[b];
}
void add(int dad, int son){
b[dad].push_back(son);
}
int lca(int u, int v){
if (dep[u] > dep[v]) swap(u, v);
for (int i = MAXLOG; i >= 0; --i)
if (dep[fa[v][i]] >= dep[u]) v = fa[v][i];
if (u == v) return u;
for (int i = MAXLOG; i >= 0; --i)
if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
void build(){
sort(q + 1, q + 1 + k, cmp);
root = lca(q[1], q[k]);
int beg = 1 + (q[1] == root);
sta[top = 1] = used[cnt = 1] = root;
for (int i = beg; i <= k; ++i){
int low = lca(q[i], sta[top]);
if (low == sta[top]) sta[++top] = q[i];
else {
while (dep[low] < dep[sta[top - 1]]){
add(sta[top - 1], sta[top]);
used[++cnt] = sta[top--];
}
if (low == sta[top - 1]) {
add(sta[top - 1], sta[top]);
used[++cnt] = sta[top--];
sta[++top] = q[i];
} else {
add(low, sta[top]);
used[++cnt] = sta[top--];
sta[++top] = low;
sta[++top] = q[i];
}
}
}
while (top >= 2){
add(sta[top - 1], sta[top]);
used[++cnt] = sta[top--];
}
}
void dfs(int pos, int dad){
dfn[pos] = ++tim;
dep[pos] = dep[dad] + 1;
fa[pos][0] = dad;
size[pos] = 1;
for (int i = 1; (1 << i) <= dep[pos]; ++i)
fa[pos][i] = fa[fa[pos][i - 1]][i - 1];
for (unsigned i = 0, si = a[pos].size(); i < si; ++i){
int son = a[pos][i];
if (son != dad){
dfs(son, pos);
size[pos] += size[son];
}
}
}
void dpI(int pos){
if (mark[pos]) dis[pos] = 0, f[pos] = pos;
for (unsigned i = 0, si = b[pos].size(); i < si; ++i){
int son = b[pos][i], wth = dep[son] - dep[pos];
dpI(son);
if (wth + dis[son] < dis[pos]) dis[pos] = wth + dis[son], f[pos] = f[son];
else if (wth + dis[son] == dis[pos] && f[son] < f[pos]) f[pos] = f[son];
}
}
void dpII(int pos){
for (unsigned i = 0, si = b[pos].size(); i < si; ++i){
int son = b[pos][i], wth = dep[son] - dep[pos];
if (dis[pos] + wth < dis[son]) dis[son] = wth + dis[pos], f[son] = f[pos];
else if (dis[pos] + wth == dis[son] && f[pos] < f[son]) f[son] = f[pos];
dpII(son);
}
}
int get(int pos, int len){
int d = 0;
while (len){
if (len & 1) pos = fa[pos][d];
len >>= 1;
++d;
}
return pos;
}
void getans(int pos){
ans[f[pos]] += size[pos];
for (unsigned i = 0, si = b[pos].size(); i < si; ++i){
int son = b[pos][i], wth = dep[son] - dep[pos], d = son;
getans(son);
if (wth == 1) d = son;
else if (f[son] == f[pos]) d = pos;
else {
int x = (dis[pos] - dis[son] + wth) / 2;
d = get(son, x);
if (dis[son] + x == dis[pos] + wth - x && f[son] > f[pos]) d = get(son, x - 1);
}
ans[f[son]] += size[d] - size[son];
ans[f[pos]] -= size[d];
}
}
int main(){
read(n);
for (int i = 1; i < n; ++i){
int u, v; read(u), read(v);
a[u].push_back(v);
a[v].push_back(u);
}
dfs(1, 0);
read(Q);
memset(dis, INF, sizeof(dis));
while (Q--) {
read(k);
for (int i = 1; i <= k; ++i)
read(q[i]), mark[q[i]] = 1;
for (int i = 1; i <= k; ++i)
p[i] = q[i];
build();
dpI(root);
dpII(root);
getans(root);
if (root != 1) ans[f[root]] += size[1] - size[root];
for (int i = 1; i <= k; ++i)
printf("%d%c", ans[p[i]], " \n"[i == k]);
for (int i = 1; i <= cnt; ++i)
b[used[i]].clear(), dis[used[i]] = INF, f[used[i]] = 0;
for (int i = 1; i <= k; ++i)
mark[q[i]] = 0, ans[q[i]] = 0;
}
return 0;
}