题意
给定一棵有 n 个节点、以 1 为根的带权无向树。共有 m 次独立询问,每次给出 k 个关键节点,要求割断若干条边,使这 k 个点都与根节点 1 不连通,求被割断边的边权和的最小值。
思路
记 dis[v] 为从根 1 到点 v 的路径上的最小边权(割掉这条边即可使 v 与根不连通)。
设 dp[u] 为使 u 子树内所有关键节点与根断开的最小代价。如果 u 是关键节点,则 dp[u] = dis[u](必须割在 u 上方);否则 dp[u] = Σ min(dis[v], dp[v]),其中 v 取遍 u 在虚树上的儿子。答案为 dp[1]。
使用虚树优化:虚树只保留根节点、关键节点以及它们两两之间的 LCA,这样每次询问只需在规模为 O(k) 的虚树上做上述 DP。
代码
#include <bits/stdc++.h>
// Sentinel stored in dis[root]. It must be strictly larger than every edge weight
// (weights are read as int), otherwise min(dis[root], w) would wrongly pick the
// sentinel instead of the real edge weight for the root's children.
constexpr long long inf = 0x3f3f3f3f3f3f3f3fll;
using namespace std;
const int N = 250500;  // array bound; every node id must be < N

// Forward-star adjacency-list entry: target vertex, index of the next entry leaving
// the same vertex (-1 terminates the list), and edge weight.
struct edge {
    int to, next, w;
} e[N << 1],   // original tree, each undirected edge stored in both directions
  e_[N << 1];  // virtual tree of the current query (w is unused and set to 0)

int n,         // number of tree nodes
    head[N],   // head[u]: first entry of u's list in e[]
    cnt,       // next free slot in e[]
    head_[N],  // head_[u]: first entry of u's list in e_[]
    cnt_;      // next free slot in e_[]
int h[N];      // key nodes of the current query, stored 1-indexed
bool vis[N];   // vis[u] is true while u is a key node of the current query
void add(int u, int v, int w) {
e[cnt] = (edge){v, head[u], w};
head[u] = cnt++;
}
// Appends the unweighted virtual-tree edge u -> v to the front of u's list in e_[].
void add_(int u, int v) {
    const int slot = cnt_++;
    e_[slot].to = v;
    e_[slot].next = head_[u];
    e_[slot].w = 0;
    head_[u] = slot;
}
// 调用init定根初始化
namespace LCA {
int log[N], depth[N], par[N][31], dfn[N];
long long dis[N];
int tot;
inline void pre() {
log[1] = 0;
for (int i = 2; i < N; ++i) {
log[i] = log[i - 1];
if ((1 << (log[i] + 1)) == i) log[i]++;
}
}
void dfs(int u, int fa, int dep) {
depth[u] = dep;
par[u][0] = fa;
dfn[u] = ++tot;
for (int i = head[u]; ~i; i = e[i].next) {
int v = e[i].to;
if (v == fa) continue;
dis[v] = min(dis[u], 1ll*e[i].w);
dfs(v, u, dep + 1);
}
}
inline void work() {
for (int j = 1; j <= log[n]; ++j)
for (int i = 1; i <= n; ++i)
par[i][j] = par[par[i][j - 1]][j - 1];
}
inline int lca(int u, int v) {
if (depth[u] < depth[v]) std::swap(u, v);
int t = depth[u] - depth[v];
for (int j = 0; j <= log[n]; ++j)
if (t >> j & 1) u = par[u][j];
if (u == v) return u;
for (int i = log[n]; ~i; --i)
if (par[u][i] != par[v][i]) {
u = par[u][i];
v = par[v][i];
}
return par[u][0];
}
inline void init(int root) {
pre();
dfs(root, -1, 0);
work();
}
} // namespace LCA
// Monotone stack holding the current root-to-node chain of the virtual tree.
// stk[1] is always the root (node 1) and top is the index of the last element.
int top;
int stk[N << 1];

// Insert key node u (key nodes must arrive sorted by LCA::dfn) into the virtual tree
// being built on stk. Chain nodes that can no longer gain new descendants are popped
// and linked to their parent via add_(). The root (node 1) is never pushed again.
void insert(int u) {
    if (u == 1) return;
    if (top != 1) {
        const int anc = LCA::lca(u, stk[top]);
        if (anc != stk[top]) {
            // Pop every chain node that lies strictly below anc, linking it to its predecessor.
            for (; top > 1 && LCA::dfn[anc] <= LCA::dfn[stk[top - 1]]; --top)
                add_(stk[top - 1], stk[top]);
            // If anc is not yet on the chain, it replaces the last popped position.
            if (stk[top] != anc) {
                add_(anc, stk[top]);
                stk[top] = anc;
            }
        }
    }
    stk[++top] = u;
}
// DP over the virtual tree rooted at u. Returns the minimum total weight needed to
// separate every key node in u's virtual subtree from the root:
//   - for a key node u, dis[u] (the cheapest edge on the root -> u path);
//   - otherwise, the sum over children v of min(dis[v], dp[v]).
// As a side effect it clears head_[] and vis[] for every visited node, so the
// structures are ready for the next query.
long long dfs(int u) {
    long long total = 0;
    for (int i = head_[u]; i != -1; i = e_[i].next) {
        const int child = e_[i].to;
        const long long below = dfs(child);
        total += std::min(LCA::dis[child], below);
    }
    head_[u] = -1;
    const bool isKey = vis[u];
    vis[u] = false;
    return isKey ? LCA::dis[u] : total;
}
// Strict weak ordering of vertices by DFS preorder index (the order insert() expects).
bool cmp(int i, int j) {
    const int *const order = LCA::dfn;
    return order[j] > order[i];
}
// Reads the tree, preprocesses the LCA tables, then answers each query by building
// the virtual tree of its key nodes and running the DP from the root.
int main() {
    scanf("%d", &n);
    memset(head, -1, sizeof head);
    for (int remaining = n - 1; remaining > 0; --remaining) {
        int a, b, w;
        scanf("%d%d%d", &a, &b, &w);
        add(a, b, w);
        add(b, a, w);
    }

    // The root has no edge above it, so its path minimum is the sentinel.
    LCA::dis[1] = inf;
    LCA::init(1);

    int m;
    scanf("%d", &m);
    memset(head_, -1, sizeof head_);
    while (m--) {
        int k;
        scanf("%d", &k);
        for (int idx = 1; idx <= k; ++idx) {
            scanf("%d", &h[idx]);
            vis[h[idx]] = true;
        }

        // Reset the virtual tree and seed the stack with the root.
        cnt_ = 0;
        top = 1;
        stk[1] = 1;

        sort(h + 1, h + 1 + k, cmp);
        for (int idx = 1; idx <= k; ++idx) insert(h[idx]);
        // Link whatever chain is still on the stack.
        for (; top > 1; --top) add_(stk[top - 1], stk[top]);

        printf("%lld\n", dfs(1));
    }
    return 0;
}