洛谷传送门
BZOJ传送门
题目描述
有一棵点数为 N N N 的树,树边有边权。给你一个在 0 ∼ N 0 \sim N 0∼N 之内的正整数 K K K ,你要在这棵树中选择 K K K个点,将其染成黑色,并将其他 的 N − K N-K N−K个点染成白色 。 将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间的距离的和的受益。问受益最大值是多少。
输入输出格式
输入格式:
第一行包含两个整数 N , K N, K N,K 。接下来 N − 1 N-1 N−1 行每行三个正整数 f r , t o , d i s fr, to, dis fr,to,dis , 表示该树中存在一条长度为 d i s dis dis 的边 ( f r , t o ) (fr, to) (fr,to) 。输入保证所有点之间是联通的。
输出格式:
输出一个正整数,表示收益的最大值。
输入输出样例
输入样例#1:
3 1
1 2 1
1 3 2
输出样例#1:
3
说明
对于 100 % 100\% 100% 的数据, 0 ≤ K ≤ N ≤ 2000 0\le K\le N \le 2000 0≤K≤N≤2000
解题分析
这道题的 d p dp dp状态不是很好设出来, 我们可以这样想:所有同色点对之间的距离之和是由每条边出现多次构成的, 而一条边的贡献次数就是这条边两侧同色点个数的乘积, 所以我们可以设 d p [ i ] [ j ] dp[i][j] dp[i][j]表示 i i i号点为根的子树中染成黑色点的个数为 j j j的时候子树内部贡献和最大为多少, 这样可以做到 O ( N 2 ) O(N^2) O(N2)复杂度。
代码如下:
#include <cstdio>
#include <cctype>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#define R register
#define IN inline
#define W while
#define gc getchar()
#define MX 2050
#define ll long long
template <class T>
IN void in(T &x)
{
x = 0; R char c = gc;
for (; !isdigit(c); c = gc);
for (; isdigit(c); c = gc)
x = (x << 1) + (x << 3) + c - 48;
}
template <class T> IN T max(T a, T b) {return a > b ? a : b;}
template <class T> IN T min(T a, T b) {return a < b ? a : b;}
int dot, lim, cnt;
ll dp[MX][MX], buf[MX];
int head[MX], siz[MX];
struct Edge {int to, len, nex;} edge[MX << 1];
IN void add(R int from, R int to, R int len)
{edge[++cnt] = {to, len, head[from]}, head[from] = cnt;}
void DP(R int now, R int fa)
{
dp[now][1] = dp[now][0] = 0; siz[now] = 1;
R int bd, arr, mx;
for (R int i = head[now]; i; i = edge[i].nex)
{
if (edge[i].to ^ fa)
{
DP(edge[i].to, now); siz[now] += siz[edge[i].to];
bd = min(siz[now], lim);
for (R int j = bd; ~j; --j)
{
mx = min(siz[edge[i].to], j);
for (R int k = 0; k <= mx; ++k)
dp[now][j] = max(dp[now][j], dp[now][j - k] + dp[edge[i].to][k] + (k * (lim - k) + 1ll * (siz[edge[i].to] - k) * (dot - lim - siz[edge[i].to] + k)) * edge[i].len);
}
}
}
}
int main(void)
{
int a, b, c;
in(dot), in(lim);
for (R int i = 1; i < dot; ++i)
in(a), in(b), in(c), add(a, b, c), add(b, a, c);
std::memset(dp, -63, sizeof(dp));
DP(1, 0);
printf("%lld", dp[1][lim]);
}