题目链接
思路:
设
d
p
[
i
]
[
0
]
dp[i][0]
dp[i][0]为子树包括自身,组成联通快且每个点的度数都<k的最大可能。
设
d
p
[
i
]
[
1
]
dp[i][1]
dp[i][1]为子树包括自身,组成联通快且至多一个点度数>=k最大可能。
那么随便从一个点出发,开始dfs。
有一种情况是dp无法解决的,就是有个点选择了k个子树,且被选中的一个子树联通快中存在>=k度数的点的情况。这时候显然就是在上面的情况中多加了一个子树dp[son][0]+weight。。。
细节见代码:
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 5e5 + 10;
#define fi first
#define se second
#define pb push_back
#define wzh(x) cerr<<#x<<'='<<x<<endl;
int n, K, t;
vector<pair<int, int> >v[N];
LL dp[N][3], ans;
pair<int, int>G[N];
int cnt;
void dfs(int x, int y) {
for (auto k : v[x]) {
if (k.fi != y)dfs(k.fi, x);
}
cnt = 0;
for (auto k : v[x]) {
if (k.fi != y) {
G[++cnt] = k;
}
}
sort(G + 1, G + 1 + cnt, [](pair<int, int>x, pair<int, int>y) {
if (dp[x.fi][0] + x.se == dp[y.fi][0] + y.se)return dp[x.fi][1] + x.se > dp[y.fi][1] + y.se;
return dp[x.fi][0] + x.se > dp[y.fi][0] + y.se;
});
for (int i = 1; i < K && i <= cnt; i++) {
dp[x][0] += dp[G[i].fi][0] + G[i].se; //
}
for (int i = 1; i <= cnt; i++) {
dp[x][1] += dp[G[i].fi][0] + G[i].se;
}
LL s = 0;
for (int i = 1; i <= min(K - 1, cnt); i++)s += dp[G[i].fi][0] + G[i].se;
if (cnt < K) {
for (int i = 1; i <= cnt; i++) {
dp[x][1] = max(dp[x][1], s - dp[G[i].fi][0] + dp[G[i].fi][1]);
}
} else {
LL dx = LLONG_MIN;
for (int i = 1; i <= K - 1; i++) {
dx = max(dx, s - dp[G[i].fi][0] + dp[G[i].fi][1]);//从前k-1个里面选一个当1
}
dp[x][1] = max(dp[x][1], dx);
LL ds = s;
s -= dp[G[K - 1].fi][0] + G[K - 1].se;
if (K > 1) {
for (int i = K; i <= cnt; i++) {
dp[x][1] = max(dp[x][1], s + dp[G[i].fi][1] + G[i].se);
}
}
ds += dp[G[K].fi][0] + G[K].se;
for (int i = 1; i <= K; i++) {
ans = max(ans, ds - dp[G[i].fi][0] + dp[G[i].fi][1]);
}
ds -= dp[G[K].fi][0] + G[K].se;
for (int i = K + 1; i <= cnt; i++) {
ans = max(ans, ds + dp[G[i].fi][1] + G[i].se);
}
}
ans = max({ans, dp[x][0], dp[x][1]});
}
int main() {
// freopen("rand.txt", "r", stdin);
// freopen("my.txt", "w", stdout);
for (scanf("%d", &t); t; t--) {
// cout << "T=" << t << ' ';
scanf("%d%d", &n, &K); ans = LLONG_MIN;
for (int i = 1; i <= n; i++) {
v[i].clear();
dp[i][0] = dp[i][1] = 0;
}
int mx = 0;
for (int i = 1; i < n; i++) {
int x, y, z;
scanf("%d%d%d", &x, &y, &z);
v[x].pb({y, z});
v[y].pb({x, z});
mx = max(mx, z);
}
if (K == 0) {
puts("0");
continue;
}
dfs(1, 0);
printf("%lld\n", ans);
}
return 0;
}