昨晚写完这题,交了三次没交上…
E. Paint the Tree
题意:给你一颗带边权的树,每个节点使用次数均为 k k k,你如果想要获得一条边的权值,那么必须要消耗该边相连的两个点的一次使用次数,问最多能获得多少的权值
解法:以1为根,设 d [ u ] [ 0 ] , d [ u ] [ 1 ] d[u][0],d[u][1] d[u][0],d[u][1]分别为 u u u子树中, u u u节点可使用次数为0和不为0所能获得的最大权值,假设 u u u有 T T T个儿子 s o n son son,我们把 d [ s o n i ] [ 0 ] , d [ s o n i ] [ 1 ] + w i , d [ s o n i ] [ 1 ] d[son_{i}][0],d[son_{i}][1]+w_{i},d[son_{i}][1] d[soni][0],d[soni][1]+wi,d[soni][1]存到优先队列,设 W i = d [ s o n i ] [ 1 ] + w i − d [ s o n i ] [ 0 ] W_{i}=d[son_{i}][1]+w_{i}-d[son_{i}][0] Wi=d[soni][1]+wi−d[soni][0],则优先队列按照 W i W_{i} Wi从大到小排序,如果有不超过 k − 1 k-1 k−1个的 W i W_{i} Wi大于0,我们最多只会消耗 u u u节点 k − 1 k-1 k−1个使用次数,那么可以直接转移: d [ u ] [ 0 ] = d [ u ] [ 1 ] = ∑ i = 1 T m a x ( d [ s o n i ] [ 0 ] , d [ s o n i [ 1 ] + w i ) d[u][0]=d[u][1]=\sum_{i=1}^{T}max(d[son_{i}][0], d[son_{i}[1]+w_{i}) d[u][0]=d[u][1]=∑i=1Tmax(d[soni][0],d[soni[1]+wi),如果有超过 k − 1 k-1 k−1个 W i W_{i} Wi大于0,那么 d [ u ] [ 0 ] = ( ∑ i = 1 k d [ s o n i ] [ 1 ] + w i ) + ∑ i = k + 1 T m a x ( d [ s o n i ] [ 0 ] , d [ s o n i ] [ 1 ] ) d[u][0]=(\sum_{i=1}^{k}d[son_{i}][1]+w_{i})+\sum_{i=k+1}^{T}max(d[son_{i}][0],d[son_{i}][1]) d[u][0]=(∑i=1kd[soni][1]+wi)+∑i=k+1Tmax(d[soni][0],d[soni][1]), d [ u ] [ 1 ] = ( ∑ i = 1 k − 1 d [ s o n i ] [ 1 ] + w i ) + ∑ i = k T m a x ( d [ s o n i ] [ 0 ] , d [ s o n i ] [ 1 ] ) d[u][1]=(\sum_{i=1}^{k-1}d[son_{i}][1]+w_{i})+\sum_{i=k}^{T}max(d[son_{i}][0],d[son_{i}][1]) d[u][1]=(∑i=1k−1d[soni][1]+wi)+∑i=kTmax(d[soni][0],d[soni][1])
#include<bits/stdc++.h>
#define pi pair<int, int>
#define mk make_pair
#define ll long long
using namespace std;
const int maxn = 5e5 + 10;
ll d[maxn][2];
int n, k;
vector<pi> G[maxn];
struct node {
ll w1, w2, w3;
bool operator<(const node& t) const {
return w2 - w1 < t.w2 - t.w1;
}
};
priority_queue<node> q;
void dfs(int u, int fa) {
for (auto tmp : G[u]) {
int v = tmp.first;
int w = tmp.second;
if (v == fa)
continue;
dfs(v, u);
}
for (auto tmp : G[u]) {
int v = tmp.first;
int w = tmp.second;
if (v == fa)
continue;
q.push(node{d[v][0], d[v][1] + w, d[v][1]});
}
int flag = 0;
for (int i = 1; i < k; i++) {
if (q.empty())
break;
ll w = max(q.top().w2, q.top().w1);
d[u][0] += w;
d[u][1] += w;
if (q.top().w1 == w)
flag = 1;
q.pop();
}
if (!q.empty()) {
ll w = max(q.top().w2, q.top().w1);
d[u][0] += w;
d[u][1] += max(q.top().w1, q.top().w3);
if (q.top().w1 == w)
flag = 1;
q.pop();
}
else
flag = 1;
while (!q.empty()) {
d[u][0] += max(q.top().w1, q.top().w3);
d[u][1] += max(q.top().w1, q.top().w3);
q.pop();
}
if (flag)
d[u][0] = d[u][1] = max(d[u][0], d[u][1]);
}
void solve() {
int u, v, w;
scanf("%d%d", &n, &k);
for (int i = 1; i < n; i++) {
scanf("%d%d%d", &u, &v, &w);
G[u].push_back(mk(v, w));
G[v].push_back(mk(u, w));
}
dfs(1, 0);
printf("%lld\n", max(d[1][0], d[1][1]));
for (int i = 1; i <= n; i++)
G[i].clear(), d[i][0] = d[i][1] = 0;
}
int main() {
int T;
scanf("%d", &T);
while (T--)
solve();
}