题面
解法
比较经典的树形dp
- 我们设 fi,j f i , j 表示点 i i 向下深度为还需要覆盖,深度为 j j 以下的点已经被完全覆盖的最小代价,表示点 i i 这棵子树已经被全部覆盖,并且还能向上延伸层的最小代价
- 那么,我们可以写出下列转移方程:
- fi,j=∑k∈sonifk,j−1 f i , j = ∑ k ∈ s o n i f k , j − 1 , gi,j=min(gi,j+fk,j,gk,j+1+fi,j+1) g i , j = m i n ( g i , j + f k , j , g k , j + 1 + f i , j + 1 )
- 现在解释一下这个状态转移方程式如何得到的
- f f 的转移方程应该不必赘述,主要是的转移方程
- 还没有更新的 gi,j g i , j 表示在 k k 前面的儿子能扩展到以上的 j j 层。那么,这棵子树里可以由前面的儿子来帮忙扩展一部分,即深度为 j j ,那么就得到了
- 还没更新的 f f 表示前面的儿子最低深度为 j j ,可以考虑由来使得 i i 能扩展到它以上的层,前面的那些子树就可以由 k k 来覆盖,那么我们就得到的
【注意事项】
- 一定是先更新 g g 再更新
- 可能会出现向上恰好覆盖 j j 层的代价小于恰好向上覆盖层的代价,这显然是会出现的,所以我们就要对状态稍微更改一下,把 fi,j f i , j 表示成最多覆盖 j j 层的最小代价,表示成至少覆盖 j j 层的最小代价就可以解决这个问题了
- 对做后缀最小值, g g 做前缀最小值就可以了
时间复杂度:
代码
#include <bits/stdc++.h>
#define N 500010
using namespace std;
template <typename node> void chkmax(node &x, node y) {x = max(x, y);}
template <typename node> void chkmin(node &x, node y) {x = min(x, y);}
template <typename node> void read(node &x) {
x = 0; int f = 1; char c = getchar();
while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
while (isdigit(c)) x = x * 10 + c - '0', c = getchar(); x *= f;
}
struct Edge {
int next, num;
} e[N * 3];
int n, d, cnt, w[N], vis[N], f[N][21], g[N][21];
void add(int x, int y) {
e[++cnt] = (Edge) {e[x].next, y};
e[x].next = cnt;
}
void dfs(int x, int fa) {
f[x][0] = g[x][0] = vis[x] * w[x];
for (int i = 1; i <= d; i++) g[x][i] = w[x];
g[x][d + 1] = 1ll << 30;
for (int p = e[x].next; p; p = e[p].next) {
int k = e[p].num;
if (k == fa) continue; dfs(k, x);
for (int i = 0; i <= d; i++) g[x][i] = min(f[k][i] + g[x][i], f[x][i + 1] + g[k][i + 1]);
for (int i = d; i >= 0; i--) g[x][i] = min(g[x][i], g[x][i + 1]); f[x][0] = g[x][0];
for (int i = 1; i <= d; i++) f[x][i] += f[k][i - 1];
for (int i = 1; i <= d; i++) f[x][i] = min(f[x][i], f[x][i - 1]);
}
}
int main() {
read(n), read(d); cnt = n;
for (int i = 1; i <= n; i++) read(w[i]);
int m; read(m);
for (int i = 1; i <= m; i++) {
int x; read(x);
vis[x] = 1;
}
for (int i = 1; i < n; i++) {
int x, y; read(x), read(y);
add(x, y), add(y, x);
}
dfs(1, 0); cout << f[1][0] << "\n";
return 0;
}