Delta Quadrant
思路
大体来说就是考虑当去掉
k
k
k 个点后,所有路径之和加起来乘上
2
2
2 就是答案,因为要往返
为什么会想到是树形
d
p
dp
dp
我们先看数据范围,本来,如果是图论,那个
k
k
k 给的范围实在是很小,同时,这里存在一个最优方案,那么考虑,去除
k
k
k 个点后最短路,但是我们想,这时间复杂度必然爆炸,那么我们再看看,就会想到树形
d
p
dp
dp,时间复杂度也很合适
树形
d
p
dp
dp 的状态
- d p [ i ] [ j ] dp[i][j] dp[i][j] 的意思是以 i i i 点为根节点的子树上,有 j j j 个节点被删了
- c n t [ i ] cnt[i] cnt[i] 的意思是以 i i i 为根节点的树上有几个点
树形
d
p
dp
dp 的转移方程
大体来看
d
p
[
u
]
[
j
]
dp[u][j]
dp[u][j] 的构成:
- d p [ u ] [ j ] dp[u][j] dp[u][j] 的最优解是由多个子树合成的 d p [ u ] [ j ] = d p [ v 1 ] [ j 1 ] + d p [ v 2 ] [ j 2 ] + . . . + d p [ v n ] [ j n ] + v a l dp[u][j] = dp[v_1][j_1] + dp[v_2][j_2] + ... + dp[v_n][j_n] + val dp[u][j]=dp[v1][j1]+dp[v2][j2]+...+dp[vn][jn]+val 构成,我们不关心过程,只求最优,像不像 d p dp dp
细节
- 注释1:
flag
表示是不是第一个子节点 - 注释2:如果不是第一个子节点,我用一个
tmp[i]
数组来临时记录删去 i i i 个点路径的最优解 - 注释3:如果当前遍历的子树的点个数
cnt[vec] > i - j
,那么临时获得的最优解就是,当前dp[u][j]
里的最优解加上新儿子删去 i − j i - j i−j 个点的最优解dp[vec][i - j]
- 注释4:如果此时无法从当前儿子里取得最优解或者是说新儿子全删去,也就是
cnt[vec] <= i - j
,那我们也没必要加上边权新儿子什么的,最优解一定是在cnt[vec] == i - j
的时候出现 - 注释5:我们能保证
tmp[i]
就是当前dp[u][i]
的最优解,直接更新 - 注释6、注释7:如果是第一个子节点,那么此时的更新就是(可以理解为初始化赋值)
dp[u][i] = val + dp[vec][i]
。注意的一点是,如果当把子树的所有节点都不要的话,那么边权val
就完全没必要,也不会从下面获得任何路径长度,这些都是在子树的节点个数小于等于 k k k 的情况下cnt[vec] <= k
- 注释8:只有当子树怎么加都凑不够
i
i
i 个点的时候,才会出现
dp[u][i] >= LNF
的情况,此时更新为 0 0 0,根本什么也加不上 - 注释9:考虑到我们可能舍去 u u u 点上面的父亲节点的所有点
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const LL LNF = 1e18;
const int N = 1e4 + 10;
int h[N], e[N * 2], ne[N * 2], w[N * 2], idx;
LL dp[N][25];
LL ans;
int n, k;
int cnt[N];
void init() {
ans = LNF;
idx = 0;
memset(h, -1, sizeof h);
for (int i = 0; i < n; i ++ )
for (int j = 0; j <= k; j ++ )
dp[i][j] = LNF;
}
void add(int a, int b, int c) {
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx ++;
}
void dfs(int u, int fa) {
cnt[u] = 1;
int flag = 0; //注释1
for (int i = h[u]; i != -1; i = ne[i]) {
int vec = e[i], val = w[i];
if (vec == fa) continue;
dfs(vec, u);
cnt[u] += cnt[vec];
if (flag) { //注释2
LL tmp[25];
for (int i = 0; i <= k; i ++ ) tmp[i] = LNF;
for (int i = 0; i <= k; i ++ )
for (int j = 0; j <= i; j ++ ) {
if (cnt[vec] > i - j) tmp[i] = min(tmp[i], dp[u][j] + dp[vec][i - j] + val); //注释3
else tmp[i] = min(tmp[i], dp[u][j]); //注释4
}
for (int i = 0; i <= k; i ++ )
dp[u][i] = tmp[i]; //注释5
} else {
for (int i = 0; i <= k; i ++ )
dp[u][i] = val + dp[vec][i]; //注释6
if (cnt[vec] <= k) dp[u][cnt[vec]] = 0; //注释7
}
flag = 1;
}
for (int i = 0; i <= k; i ++ )
if (dp[u][i] >= LNF) dp[u][i] = 0; //注释8
if (n - cnt[u] <= k)
ans = min(ans, dp[u][k - n + cnt[u]]); //注释9
}
int main() {
int t;
scanf("%d", &t);
while (t -- ) {
init();
scanf("%d%d", &n, &k);
for (int i = 1; i < n; i ++ ) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
add(u, v, w), add(v, u, w);
}
dfs(0, -1);
printf("%lld\n", ans * 2);
}
return 0;
}