题目
解题思路
建立虚树,在虚树上
d
p
dp
dp。
d
p
[
u
]
[
0
]
dp[u][0]
dp[u][0]表示只考虑
u
u
u的子树,
u
u
u子树里的关键点和
u
u
u都被断开的最小切断次数。
d
p
[
u
]
[
1
]
dp[u][1]
dp[u][1]表示只考虑
u
u
u的子树,
u
u
u子树里的关键点只有一个关键点和
u
u
u相连的最小切断次数。
这里需要注意的是,如果虚树上的父亲和实数上的父亲不同,则可以在虚树点和其虚树父亲间选个点切断。
最近公共祖先选择
d
f
s
dfs
dfs序上建立
S
T
ST
ST表,支持
O
(
n
l
o
g
2
n
)
O(nlog_2n)
O(nlog2n)预处理
O
(
1
)
O(1)
O(1)查询。
虚树建立使用了
s
o
r
t
(
)
sort()
sort()函数,复杂度
O
(
n
l
o
g
2
n
)
O(nlog_2n)
O(nlog2n)。如果选择线性排序,复杂度可到
O
(
n
)
O(n)
O(n)。
树形dp复杂度
O
(
n
)
O(n)
O(n)。
整体复杂度
O
(
n
l
o
g
2
n
)
O(nlog_2n)
O(nlog2n)。
代码
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
typedef long long ll;
void read(int &x) {
x = 0; char c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
}
void write(int x) {
if (x > 9) write(x / 10);
putchar(x % 10 + '0');
}
const int N = 4e5 + 100;
int n, tot;
int bg[N], ed[N], dep[N], st[20 + 1][N * 2], Log[N * 2], faz[N];
vector<int> V[N];
void dfs(int u, int fa) {
st[0][++tot] = u;
bg[u] = tot;
faz[u] = fa;
for (int v : V[u]) {
if (v == fa) continue;
dep[v] = dep[u] + 1;
dfs(v, u);
st[0][++tot] = u;
}
ed[u] = tot;
}
bool cmp1(int a, int b) {
return dep[a] < dep[b];
}
void init() {
dfs(1, 0);
for (int i = 1, t = 1, c = 0; i <= tot; i++) {
if (t * 2 == i) t *= 2, c++;
Log[i] = c;
}
for (int j = 0; j < 19; j++)
for (int i = 1; i <= tot; i++)
st[j + 1][i] = min(st[j][i], st[j][min(i + (1 << j), tot)], cmp1);
}
int lca(int a, int b) {
if (bg[a] > bg[b]) swap(a, b);
int g = Log[bg[b] - bg[a] + 1];
return min(st[g][bg[a]], st[g][bg[b] - (1 << g) + 1], cmp1);
}
int m, tp;
int has[N], sta[N];
bool cmp2(int a, int b) {
return bg[a] < bg[b];
}
ll dp[N][2];
bool flag[N];
void dfs1(int u, int fa) {
for (int v : V[u]) dfs1(v, u);
if (flag[u]) {
dp[u][0] = 1e9;
dp[u][1] = 0;
for (int v : V[u]) dp[u][1] += dp[v][0];
}
else {
ll sum1 = 0, sum2 = 0, res = -1e18;
for (int v : V[u]) {
sum1 += dp[v][0];
sum2 += min(dp[v][1], dp[v][0]);
res = max(res, dp[v][0] - dp[v][1]);
}
dp[u][0] = min(sum1, sum2 + 1);
dp[u][1] = min(dp[u][0], sum1 - res);
}
if (faz[u] != fa) dp[u][0] = min(dp[u][0], dp[u][1] + 1);
}
void clr(int u) {
for (int v : V[u]) clr(v);
V[u].clear();
flag[u] = false;
dp[u][0] = dp[u][1] = 0;
}
int main() {
//freopen("0.txt", "r", stdin);
int a, b, q;
read(n);
for (int i = 1; i < n; i++) {
read(a); read(b);
V[a].push_back(b);
V[b].push_back(a);
}
init();
for (int i = 1; i <= n; i++) V[i].clear();
read(q);
while (q--) {
read(m);
for (int i = 1; i <= m; i++) {
read(has[i]);
flag[has[i]] = true;
}
sort(has + 1, has + m + 1, cmp2);
sta[tp = 1] = has[1];
for (int i = 2; i <= m; i++) {
int g = lca(sta[tp], has[i]);
while (tp > 0 && dep[g] < dep[sta[tp]]) {
if (tp > 1 && dep[g] < dep[sta[tp - 1]]) V[sta[tp - 1]].push_back(sta[tp]);
else V[g].push_back(sta[tp]);
tp--;
}
if (tp == 0 || g != sta[tp]) sta[++tp] = g;
sta[++tp] = has[i];
}
while (tp > 1) V[sta[tp - 1]].push_back(sta[tp]), tp--;
dfs1(sta[1], 0);
ll ans = min(dp[sta[1]][0], dp[sta[1]][1]);
if (ans > 1e8) puts("-1");
else write(ans), puts("");
clr(sta[1]);
}
return 0;
}