题目大意:给出一棵以 1 为根、带边权的树,有 m 次询问,每次询问给出 k 个关键点,求删去若干条边、使这 k 个点都与根节点不连通的最小代价(即删去的边的边权之和的最小值)。
数据范围:$n \leq 250000,\ m \geq 1,\ \sum k \leq 500000$。
如果每次询问都在整棵树上暴力 DP,总复杂度是 $O(nm)$,难以承受。
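对每次询问建立虚树,并预处理每个点到根节点路径上的最小边权 $val_v$,然后在虚树上 DP:若 $v$ 是关键点,则 $f_v = val_v$;否则 $f_v = \min\left(\sum_{c} f_c,\ val_v\right)$,其中 $c$ 取遍 $v$ 在虚树上的儿子。答案为 $f_1$,总复杂度为 $O(\sum k \log n)$。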
代码:
#include<bits/stdc++.h>
using namespace std;
// Shared state for the whole program (single test, many queries).
// pii holds (neighbour, edge weight) pairs in g[]; fir/sec alias first/second.
#define pii pair<int,int>
#define fir first
#define sec second
// Array bound; the statement guarantees n <= 250000.
const int maxn = 3e5 + 10;
// NOTE(review): not referenced anywhere in this file.
const int inf = 0x3f3f3f3f;
typedef long long ll;
// n: vertex count, m: number of queries, k: size of the current query.
int n,m,k;
// Original tree as an adjacency list: g[u] holds (v, weight) for each edge u-v.
vector<pii> g[maxn];
// Virtual tree of the current query (undirected edges); dfs clears it again.
vector<int> h[maxn];
// val[v]: minimum edge weight on the path from root 1 to v (val[1] = 2^60).
ll val[maxn];
// p[v][i]: 2^i-th ancestor of v for i = 0..18 (0 above the root).
// w[v][i]: binary-lifting edge-minimum companion of p; NOTE(review): it is
//          never read by getlca, dfs or main (path minima come from val[]),
//          so it looks removable.
// dep[v]: depth with the root at 1 (dep[0] stays 0); dfn[v]: DFS entry time;
// cnt: running timestamp for dfn.
int p[maxn][20],w[maxn][20],dep[maxn],dfn[maxn],cnt;
// sta[1..top]: stack used by insert() while building the virtual tree;
// a[1..k]: vertices of the current query; is[v]: 1 if v is a query vertex
// (reset to 0 by dfs).
int sta[maxn],top,a[maxn],is[maxn];
// Reads the next decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. If EOF is reached before any digit, this returns
// 0 instead of looping forever. The previous version stored getchar() in a
// `char`, which can never compare equal to EOF, so truncated input made the
// skip loop spin indefinitely.
inline int read() {
    int c = getchar();  // int, not char: must be able to hold EOF
    int x = 0, f = 1;
    while(c != EOF && (c < '0' || c > '9')) {
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
// Preprocesses the tree for LCA queries and root-path minima, starting at
// vertex `u` whose (non-existent) parent is `fa`. main calls prework(1, 0).
//
// For every vertex v reached from u without going back through `fa`, it fills:
//   dfn[v]  - preorder entry time (same order as a recursive DFS over g[]),
//   dep[v]  - depth, with the start vertex at depth 1,
//   p[v][i] - 2^i-th ancestor for i = 0..18 (0 above the start vertex),
//   val[v]  - minimum edge weight on the path from the start vertex to v;
//             the start vertex itself gets 2^60 ("no edge yet").
// The start vertex is always treated as the root.
//
// The DFS uses an explicit stack instead of recursion. A path-shaped tree with
// n = 250000 vertices would otherwise nest 250000 calls and can overflow the
// default call stack. The w[][] table is no longer filled here, because
// nothing reads it; path minima come from val[].
void prework(int u,int fa) {
    // One pending DFS frame: the vertex, the vertex it was entered from, and
    // the index of the next g[v] entry still to examine.
    struct Frame {
        int v;
        int parent;
        std::size_t next;
    };
    // Work done when a vertex is first reached. p[v][0] must already be set
    // (it is 0 for the start vertex).
    auto enter = [](int v) {
        dfn[v] = ++cnt;
        for(int i = 1; i <= 18; i++)
            p[v][i] = p[p[v][i - 1]][i - 1];
    };

    val[u] = 1ll << 60;
    dep[u] = 1;
    enter(u);

    std::vector<Frame> stk;
    stk.push_back(Frame{u, fa, 0});
    while(!stk.empty()) {
        Frame &cur = stk.back();
        if(cur.next == g[cur.v].size()) {  // all neighbours handled: backtrack
            stk.pop_back();
            continue;
        }
        const int from = cur.v;
        const int to = g[from][cur.next].fir;
        const int weight = g[from][cur.next].sec;
        ++cur.next;
        if(to == cur.parent) continue;
        p[to][0] = from;
        val[to] = std::min(val[from], 1ll * weight);
        dep[to] = dep[from] + 1;
        enter(to);
        // push_back may reallocate, so `cur` is not used after this point.
        stk.push_back(Frame{to, from, 0});
    }
}
// Sort predicate: a goes before b when it was entered earlier in the DFS,
// i.e. orders query vertices by their dfn[] timestamp.
int cmp(int a,int b) {
    const int entryA = dfn[a];
    const int entryB = dfn[b];
    return entryB > entryA;
}
int getlca(int x,int y) {
if(dep[x] < dep[y]) swap(x,y);
for(int i = 18; i >= 0; i--) {
if(dep[p[x][i]] >= dep[y])
x = p[x][i];
}
if(x == y) return x;
for(int i = 18; i >= 0; i--) {
if(p[x][i] != p[y][i]) {
x = p[x][i];
y = p[y][i];
}
}
return p[x][0];
}
// Adds query vertex x to the virtual tree that is being built in h[].
// Vertices must be inserted in increasing dfn[] order. sta[1..top] holds the
// chain of virtual-tree vertices from the root (sta[1], pushed by main) down
// to the most recently inserted vertex. Edges are added to h[] in both
// directions once a vertex's virtual-tree parent is certain.
void insert(int x) { // builds the virtual tree incrementally
int lca = getlca(x,sta[top]);
// While the second-from-top chain vertex is at or below lca (dfn >= dfn[lca]),
// the top vertex's virtual-tree parent is final: link the two and pop.
while(top > 1 && dfn[sta[top - 1]] >= dfn[lca]) {
h[sta[top - 1]].push_back(sta[top]);
h[sta[top]].push_back(sta[top - 1]);
top--;
}
// The remaining top vertex lies strictly below lca: hang it under lca and let
// lca take its place on the chain.
if(sta[top] != lca) {
h[lca].push_back(sta[top]);
h[sta[top]].push_back(lca);
sta[top] = lca;
}
// x extends the chain.
sta[++top] = x;
}
// Runs the DP on the virtual tree h[] rooted at `u`, whose parent is `fa`.
// It returns the minimum total weight of edges to delete so that no query
// vertex (is[v] != 0) in that subtree stays connected to the root:
//   query vertex v     -> val[v] (cheapest edge on the root-to-v path);
//   non-query vertex v -> min(sum of its children's costs, val[v]).
// Side effect: every visited vertex has h[v] cleared and is[v] reset to 0,
// which readies the globals for the next query.
//
// Iterative instead of recursive. Vertices are first listed parent-before-child,
// then evaluated in reverse, so a path-shaped virtual tree (depth up to about n)
// cannot overflow the call stack. It uses function-local static buffers,
// so it is not reentrant.
ll dfs(int u,int fa) {
    // order[i] = (vertex, index in `order` of its parent, or -1 for u itself).
    static std::vector<std::pair<int,int>> order;
    // acc[i] = sum of the children's costs of order[i].
    static std::vector<ll> acc;

    order.clear();
    order.push_back(std::make_pair(u, -1));
    for(std::size_t i = 0; i < order.size(); i++) {
        const int v = order[i].first;
        const int up = order[i].second < 0 ? fa : order[order[i].second].first;
        for(int c : h[v]) {
            if(c != up) order.push_back(std::make_pair(c, static_cast<int>(i)));
        }
    }

    acc.assign(order.size(), 0);
    ll result = 0;
    // Reverse order guarantees every child is finished before its parent.
    for(std::size_t i = order.size(); i-- > 0; ) {
        const int v = order[i].first;
        ll cost;
        if(is[v]) {
            is[v] = 0;
            cost = val[v];
        } else {
            cost = std::min(acc[i], val[v]);
        }
        h[v].clear();
        if(order[i].second < 0) result = cost;
        else acc[order[i].second] += cost;
    }
    return result;
}
// Reads the weighted tree, preprocesses it from root 1, then answers each
// query by building the virtual tree of its vertices and running the DP.
int main() {
    n = read();
    // n - 1 undirected weighted edges.
    for(int edge = 1; edge < n; edge++) {
        const int from = read();
        const int to = read();
        const int weight = read();
        g[from].push_back(pii(to, weight));
        g[to].push_back(pii(from, weight));
    }
    prework(1, 0);

    m = read();
    while(m-- != 0) {
        k = read();
        for(int i = 1; i <= k; i++) {
            a[i] = read();
            is[a[i]] = 1;
        }
        // Virtual-tree construction needs the vertices in DFS order.
        std::sort(a + 1, a + k + 1, cmp);
        top = 0;
        sta[++top] = 1;
        for(int i = 1; i <= k; i++) insert(a[i]);
        // Whatever is left on the stack forms the last root-to-leaf chain.
        for(int i = 1; i + 1 <= top; i++) {
            const int upper = sta[i];
            const int lower = sta[i + 1];
            h[upper].push_back(lower);
            h[lower].push_back(upper);
        }
        printf("%lld\n", dfs(1, 0));
    }
    return 0;
}