一.题目
传送门
翻译:
题意翻译:
给定一个有 N 个节点的树,每个节点要染上 K 种颜色,有无数多种颜色,每种颜色最多用两次。当一条边的两个节点附上的颜色中有至少一种相同颜色时,这条边的贡献就是它的权值,否则贡献为 0。
求这颗树所有边最大的贡献之和。
数据范围:
1
≤
T
≤
5
⋅
1
0
5
1\leq T \leq 5 · 10^5
1≤T≤5⋅105
1
≤
N
,
K
≤
5
⋅
1
0
5
1\leq N,K \leq 5 · 10^5
1≤N,K≤5⋅105
N
的
和
不
超
过
5
⋅
1
0
5
N的和不超过5·10^5
N的和不超过5⋅105
二.题解
这道题目有点巧妙,是一道树形DP。
既然有一个每个颜色最多不能用超过两次,所以每个点与其对应有贡献的边最多有k条。
定义f[i][0]为i不向它的父亲连边最大的贡献
f[i][1]为i向他的父亲连边最大的贡献
然后就有
f
[
f
a
]
[
0
]
=
f
[
f
a
]
[
1
]
=
∑
f
[
x
]
[
0
]
(
f
a
是
x
的
父
节
点
)
f[fa][0] = f[fa][1] = \sum f[x][0](fa是x的父节点)
f[fa][0]=f[fa][1]=∑f[x][0](fa是x的父节点)
于是如何运用f[x][1]呢?
我们处理一个f[x][0]转换成f[x][1]产生的贡献,即:
f
[
x
]
[
1
]
+
w
[
f
a
]
[
x
]
−
f
[
x
]
[
0
]
f[x][1] + w[fa][x] - f[x][0]
f[x][1]+w[fa][x]−f[x][0]
然后将其排序,从大到小取贡献值,不取贡献值为负的就行了。
sort (now.begin (), now.end (), cmp);//now中存了所有贡献,从大到小排序
for (int i = 0; i < now.size () && i < k; i ++){//不能超过k条边
if (now[i] <= 0)
break;
if (i < k - 1)//因为f[x][1]要向x的父亲连边,所以它的儿子向它连的边不能超过k-1条,i从0开始。
f[x][1] += now[i];
f[x][0] += now[i];
}
三.Code
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
#define M 500005
#define LL long long
struct node {
int v, w;
node (){};
node (int V, int W){
v = V;
w = W;
}
};
int t, n, k;
LL ans, f[M][5];
vector <node> G[M];
void read (int &x){
x = 0; int f = 1; char c = getchar ();
while (c < '0' || c > '9') {if (c == '-') f = -1; c = getchar ();}
while (c >= '0' && c <= '9') {x = x * 10 + c - 48; c = getchar ();}
x *= f;
}
bool cmp (LL x, LL y){
return x > y;
}
void dfs (int x, int fa){
vector <LL> now;
for (int i = 0; i < G[x].size (); i ++){
int tmp = G[x][i].v, tot = G[x][i].w;
if (tmp != fa){
dfs (tmp, x);
f[x][0] = f[x][0] + f[tmp][0];
f[x][1] = f[x][1] + f[tmp][0];
now.push_back (f[tmp][1] + 1ll * tot - f[tmp][0]);
}
}
sort (now.begin (), now.end (), cmp);
for (int i = 0; i < now.size () && i < k; i ++){
if (now[i] <= 0)
break;
if (i < k - 1)
f[x][1] += now[i];
f[x][0] += now[i];
}
}
int main (){
read (t);
while (t --){
read (n), read (k);
for (int i = 1; i <= n; i ++){
G[i].clear ();
f[i][0] = f[i][1] = 0;
}
for (int i = 1; i < n; i ++){
int u, v, w;
read (u), read (v), read (w);
G[u].push_back (node (v, w));
G[v].push_back (node (u, w));
}
dfs (1, 0);
printf ("%lld\n", f[1][0]);
}
return 0;
}