[HAOI2015] 树上染色
题目描述
有一棵点数为 n n n 的树,树边有边权。给你一个在 0 ∼ n 0 \sim n 0∼n 之内的正整数 k k k ,你要在这棵树中选择 k k k 个点,将其染成黑色,并将其他的 n − k n-k n−k 个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间的距离的和的收益。问收益最大值是多少。
输入格式
第一行包含两个整数 n , k n,k n,k。
第二到 n n n 行每行三个正整数 u , v , w u, v, w u,v,w,表示该树中存在一条长度为 w w w 的边 ( u , v ) (u, v) (u,v)。输入保证所有点之间是联通的。
输出格式
输出一个正整数,表示收益的最大值。
样例 #1
样例输入 #1
3 1
1 2 1
1 3 2
样例输出 #1
3
提示
对于 100 % 100\% 100% 的数据, 0 ≤ n , k ≤ 2000 0 \leq n,k \leq 2000 0≤n,k≤2000。
分析
一看枚举所有染色情况就不现实,所以这种最值一般都是DP(组合数也都算不出来),而且允许 n 2 n^2 n2 的复杂度,满足树形DP。
这种没法枚举所有情况的时候必然要考虑从中间某个地方入手(参考Moving Dots),所以考虑对每条边计算贡献。因为是一棵树,所以贡献就与两侧黑白点的数目有关,对于每一条边,设一侧子树黑点为k个,题目要求选的黑点有m个,当前子树大小为 s i z [ k ] siz[k] siz[k] ,分数 t o t = k ∗ ( m − k ) + ( s i z [ v ] − k ) ∗ ( n − ( m − k ) − s i z [ v ] ) tot=k*(m-k)+(siz[v]-k)*(n-(m-k)-siz[v]) tot=k∗(m−k)+(siz[v]−k)∗(n−(m−k)−siz[v])
有了这个计算式,就大概可以树形DP了, d p [ u ] [ i ] dp[u][i] dp[u][i] 表示以 u u u 为跟的子树中,选择 i i i 个黑节点,对答案有多少贡献
转移方程为 d p [ u ] [ i ] = m a x ( d p [ u ] [ i ] , d p [ u ] [ i − j ] + d p [ v ] [ j ] + t o t ∗ e [ i ] . w ) dp[u][i] = max( dp[u][i], dp[u][i-j] + dp[v][j] + tot*e[i].w) dp[u][i]=max(dp[u][i],dp[u][i−j]+dp[v][j]+tot∗e[i].w)
其中 v v v 为 u u u 的子节点, j j j 为在这个子节点中选择的黑色点的个数, t o t tot tot 为这条边的贡献
会发现这个转移非常有背包的样子,枚举的时候也按照背包做就行,注意判好边界。
上代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long ll;
struct node
{
int to,next,w;
}e[4010];
int n,m,siz[2010];
int hd[4010],tot;
ll f[2010][2010],ans;
void add(int x,int y,int w)
{
e[++tot]=(node){y,hd[x],w};
hd[x]=tot;
}
void dp(int x,int fx)
{
siz[x]=1;
f[x][0]=f[x][1]=0;
for(int i=hd[x];i;i=e[i].next)
{
int v=e[i].to;
if(v==fx) continue;
dp(v,x);
/*树上背包*/
siz[x]+=siz[v];
for(int j=min(m,siz[x]);j>=0;j--)
{
for(int k=0;k<=min(j,siz[v]);k++)
{
if(f[x][j-k]==-1) continue;
ll tot=k*(m-k)+(siz[v]-k)*(n-m-siz[v]+k);
f[x][j]=max(f[x][j],f[v][k]+f[x][j-k]+tot*e[i].w);
}
}
}
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n-1;i++)
{
int x,y,w;
scanf("%d%d%d",&x,&y,&w);
add(x,y,w);
add(y,x,w);
}
memset(f,-1,sizeof(f));
dp(1,0);
cout<<f[1][m];
return 0;
}