Description
有一棵点数为N的树,树边有边权。给你一个在0~N之内的正整数K,你要在这棵树中选择K个点,将其染成黑色,并
将其他的N-K个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。
问收益最大值是多少。
Input
第一行两个整数N,K。
接下来N-1行每行三个正整数fr,to,dis,表示该树中存在一条长度为dis的边(fr,to)。
输入保证所有点之间是联通的。
N<=2000,0<=K<=N
Output
输出一个正整数,表示收益的最大值。
Sample Input
5 2
1 2 3
1 5 1
2 3 1
2 4 2
Sample Output
17
【样例解释】
将点1,2染黑就能获得最大收益。
题解
树形DP。
首先不难想到设
fi,j
为以编号为i的节点为根的子树中有j个黑色节点对答案的贡献。
这里发现不好转移,所以把该子树内的点与子树外的点组合所产生的权值也计算进去。考虑统计所有边权对答案的贡献,一条边对答案产生的贡献为
边权∗(子树内黑色点数量∗子树外黑色点数量+子树内白色点数量∗子树外白色点数量)
。
用DFS来求,枚举 i 的每个儿子 j,现在的
f[i]
是包含了
[1,j−1]
子树,然后两重循环枚举范围是
[1,j−1]
的子树总 Size 和 j 的 Size,来更新
f[i]
,这样更新之后的
f[i]
就是
[1,j]
子树的答案了。
通过奥妙重重的方法可以发现每个点对
(u,v)
只会在其
lca
处被考虑到,所以复杂度是
O(N)
。
//这一版图方便结果跑得奇慢无比,虽然是过了。
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
inline int read(){
int x = 0, f = 1; char c = getchar();
while(c < '0' || c > '9') { if(c == '-') f = -1; c = getchar(); }
while(c >= '0' && c <='9') { x = x * 10 + c - '0'; c = getchar(); }
return x * f;
}
const int N = 2000 + 10, inf = 0x3f3f3f3f;
int to[N<<1], val[N<<1], hd[N], nxt[N<<1], tot;
int siz[N];
int n, k;
ll f[N][N];
void insert(int u, int v, int w){
to[++tot] = v; val[tot] = w; nxt[tot] = hd[u]; hd[u] = tot;
to[++tot] = u; val[tot] = w; nxt[tot] = hd[v]; hd[v] = tot;
}
void init(){
n = read(); k = read();
int u, v, w;
for(int i = 1; i < n; i++){
u = read(); v = read(); w = read();
//printf("%d %d %d\n", u, v, w);
insert(u, v, w);
}
}
void dfs(int u, int fa){
f[u][0] = f[u][1] = 0;
siz[u] = 1;
for(int i = hd[u]; i; i = nxt[i]){
int v = to[i];
if(v == fa) continue;
dfs(v, u);
siz[u] += siz[v];
for(int x = min(k, siz[u]); x >= 0; x--)
for(int y = 0; y <= min(siz[v], x); y++)
f[u][x] = max(f[u][x], f[u][x-y] + f[v][y] + (ll)val[i] * ((ll)(y*(k-y)) + (ll)(siz[v]-y)*(ll)(n-k-(siz[v]-y))));
}
}
void work(){
memset(f, 128, sizeof(f));
dfs(1, 0);
cout<<f[1][k]<<endl;
}
int main(){
init();
work();
return 0;
}
//把每条边的权值放到子树里处理,时间表现一下子正常了许多。
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
inline int read(){
int x = 0, f = 1; char c = getchar();
while(c < '0' || c > '9') { if(c == '-') f = -1; c = getchar(); }
while(c >= '0' && c <='9') { x = x * 10 + c - '0'; c = getchar(); }
return x * f;
}
const int N = 2000 + 10, inf = 0x3f3f3f3f;
int to[N<<1], val[N<<1], hd[N], nxt[N<<1], tot;
int siz[N];
int n, k;
ll f[N][N], tmp[N];
void insert(int u, int v, int w){
to[++tot] = v; val[tot] = w; nxt[tot] = hd[u]; hd[u] = tot;
to[++tot] = u; val[tot] = w; nxt[tot] = hd[v]; hd[v] = tot;
}
void init(){
n = read(); k = read();
int u, v, w;
for(int i = 1; i < n; i++){
u = read(); v = read(); w = read();
insert(u, v, w);
}
}
void dfs(int u, int fa, int pre){
siz[u] = 1;
for(int i = hd[u]; i; i = nxt[i]){
int v = to[i];
if(v == fa) continue;
dfs(v, u, val[i]);
for(int j = 0; j <= min(siz[u], k); j++) tmp[j] = f[u][j];
for(int x = 0; x <= min(siz[u], k); x++)
for(int y = 0; y <= min(siz[v], k); y++)
f[u][x+y] = max(f[u][x+y], tmp[x] + f[v][y]);
siz[u] += siz[v];
}
for(int i = 0; i <= min(siz[u], k); i++)
f[u][i] += (ll)pre * ((ll)i*(k-i) + (ll)(siz[u]-i)*(n-k-(siz[u]-i)));
}
void work(){
dfs(1, 0, 0);
cout<<f[1][k]<<endl;
}
int main(){
init();
work();
return 0;
}