题意:
给一棵带权树,1为根,q次询问,每次询问给出k个特殊点,要求去掉一些边,代价是权值,让点1不能到达k里的任意一个,输出最小代价。
题解:
首先考虑单次询问,容易想到dp[x]表示让1不能到x子树里的特殊点的代价。
当x是特殊点的时候,显然有dp[x]= min(path[x]),表示x到根节点的最小的那条边,因为一定要去掉x。
当x不是的时候,dp[x]=min(path[x], sum(dp[son(x)])),表示要么去掉x,要么去掉x所有儿子的子树的特殊点。
答案就是dp[1]。
现在变成多次询问,但是总的特殊点不超过
105
,需要用到虚树的技巧。
看了几篇blog学了一下虚树,这篇的写法很不错。
照着思路写就行了,代码还是比较容易的。
#include<bits/stdc++.h>
using namespace std;
const int N = 3e5+5;
typedef long long ll;
typedef pair<int,ll> edg;
const ll inf = ~0ull>>5;
vector<edg>G[N];
vector<int>qry, GG[N];
ll cost[N];
int dep[N], fa[N][25], in[N], cnt = 0;
int st[N], top;
bool vis[N];
inline bool cmp(const int& a, const int& b){
return in[a] < in[b];
}
void dfs1(int rt, int f){
in[rt] = ++cnt;
for(int i = 1; i < 25; ++i){
if(dep[rt] < (1<<i)) break;
fa[rt][i] = fa[fa[rt][i-1]][i-1];
}
for(int i = 0; i < G[rt].size(); ++i){
int v = G[rt][i].first;
if(v == f) continue;
cost[v] = min(cost[rt], G[rt][i].second);
dep[v] = dep[rt]+1;
fa[v][0] = rt;
dfs1(v, rt);
}
}
int lca(int a, int b){
if(dep[a] < dep[b]) swap(a, b);
int delt = dep[a] - dep[b];
for(int i = 0; i < 25; ++i){
if(delt&(1<<i)) a = fa[a][i];
}
for(int i = 24; i >= 0; --i){
if(fa[a][i] != fa[b][i]) a = fa[a][i], b = fa[b][i];
}
if(a != b) return fa[a][0];
else return a;
}
void add(int a, int b){
if(a == b) return;
GG[a].push_back(b);
}
ll dp[N];
void solve(int rt){
dp[rt] = cost[rt];
if(!vis[rt]){
ll sum = 0;
for(int i = 0; i < GG[rt].size(); ++i){
solve(GG[rt][i]);
sum += dp[GG[rt][i]];
}
dp[rt] = min(dp[rt], sum);
}
}
int main(){
int n;
scanf("%d", &n);
for(int a, b, c, i = 1; i < n; ++i){
scanf("%d%d%d", &a, &b, &c);
G[a].push_back(edg(b, c));
G[b].push_back(edg(a, c));
}
cost[1] = inf;
dfs1(1, 1);
int q;
scanf("%d", &q);
while(q--){
qry.clear();
int k;
scanf("%d", &k);
for(int a, i = 0; i < k; ++i){
scanf("%d", &a);
qry.push_back(a);
vis[a] = 1;
}
sort(qry.begin(), qry.end(), cmp);
int len = qry.size();
for(int i = 1; i < len; ++i){
qry.push_back(lca(qry[i], qry[i-1]));
}
qry.push_back(1);
sort(qry.begin(), qry.end(), cmp);
qry.erase(unique(qry.begin(), qry.end()), qry.end());
int top = 0;
for(int i = 0; i < qry.size(); ++i){
while(top > 0 && lca(st[top], qry[i]) != st[top]) top--;
if(top) add(st[top], qry[i]);
st[++top] = qry[i];
}
solve(1);
for(int i = 0; i < qry.size(); ++i) GG[qry[i]].clear(), vis[qry[i]] = 0;
printf("%lld\n", dp[1]);
}
}