P2495 [SDOI2011]消耗战
题解:
虚树\(dp\)入门题吧。虚树的核心思想其实就是每次只保留关键点,因为关键点的dfs序的相对大小顺序和原来的树中结点dfs序的相对大小顺序都是一样的,所以可以就求出dfs序并且利用它来构造。最后的图中只有关键点以及某些关键点对的lca。
具体构造方法就是利用一个栈,假设当前插入结点为\(x\),求出栈顶元素和\(x\)的lca,如果栈顶元素为lca,那么我们就继续延长这条链;否则(此时栈顶元素和\(x\)在lca的两颗子树上面)就将栈顶元素到\(lca\)上面的点弹出来并且建边,然后继续延长\(x\)这边的链。
感觉说得不是很清楚,详见代码吧:
#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
using namespace std;
typedef long long ll;
const int N = 250005 ;
int n, m;
struct Edge{
int v, next, w;
}e[N << 1];
int head[N], tot;
void adde(int u, int v, int w) {
e[tot].v = v; e[tot].w = w; e[tot].next = head[u]; head[u] = tot++;
}
vector <int> g[N] ;
int f[N][20], deep[N], dfn[N];
ll mn[N];
int T;
void dfs(int u, int fa) {
deep[u] = deep[fa] + 1;
dfn[u] = ++T;
for(int i = head[u]; i != -1; i = e[i].next) {
int v = e[i].v;
if(v == fa) continue ;
mn[v] = min((ll)e[i].w, mn[u]) ;
f[v][0] = u;
for(int j = 1; j <= 17; j++) f[v][j] = f[f[v][j - 1]][j - 1] ;
dfs(v, u) ;
}
}
int LCA(int x, int y) {
if(deep[x] < deep[y]) swap(x, y) ;
for(int i = 17; i >= 0; i--) {
if(deep[f[x][i]] >= deep[y]) x = f[x][i] ;
}
if(x == y) return x;
for(int i = 17; i >= 0; i--) {
if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i] ;
}
return f[x][0] ;
}
int sta[N], a[N];
int top;
bool cmp(const int &x, const int &y) {
return dfn[x] < dfn[y] ;
}
void add_edge(int u, int v) {
g[u].push_back(v) ;
}
void insert(int x) {
int lca = LCA(x, sta[top]) ;
if(top == 1) {sta[++top] = x; return ;}
if(lca == sta[top]) return ;
while(top > 1 && dfn[sta[top - 1]] >= dfn[lca]) {
add_edge(sta[top - 1], sta[top]); top--;
}
if(sta[top] != lca) add_edge(lca, sta[top]), sta[top] = lca;
sta[++top] = x;
}
ll DP(int u) {
if(g[u].size() == 0) return mn[u];
ll sum = 0;
for(auto v : g[u]) sum += DP(v) ;
g[u].clear() ;
return min(sum, (ll)mn[u]) ;
}
int main() {
ios::sync_with_stdio(false); cin.tie(0);
cin >> n;
mn[1] = 1ll << 56;
memset(head, -1, sizeof(head)) ;
for(int i = 1; i < n; i++) {
int u, v, w;
cin >> u >> v >> w;
adde(u, v, w); adde(v, u, w) ;
}
dfs(1, 0) ;
cin >> m;
while(m--) {
int k; cin >> k;
for(int i = 1; i <= k; i++) cin >> a[i] ;
sort(a + 1, a + k + 1, cmp) ;
sta[top = 1] = 1;
for(int i = 1; i <= k; i++) insert(a[i]) ;
while(top > 1) add_edge(sta[top - 1], sta[top]), top--;
cout << DP(1) << '\n' ;
}
return 0 ;
}