【描述】
有一棵点数为 N 的树,树边有边权。给你一个在 0∼N 之内的正整数K,你要在这棵树中选择 K 个点,将其染成黑色,并将其他的 N−K个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。
问收益最大值是多少。
【输入】
第一行两个整数 N,K。
接下来 N−1 行每行三个正整数 fr,to,dis,表示该树中存在一条长度为 dis 的边 (fr,to)
输入保证所有点之间是联通的。
【输出】
输出一个正整数,表示收益的最大值。
【样例输入】
5 2
1 2 3
1 5 1
2 3 1
2 4 2
【样例输出】
17
【思路】
显然,这是一道树上背包dp。状态定义也十分显然
f
[
i
]
[
j
]
f[i][j]
f[i][j]表示i子树内有j个黑点的贡献。
1.考虑转移
我们首先考虑一对点对于i子树的贡献,发现我们必须知道点的具体位置才能计算贡献。考虑换一个思路,我们考虑子树内每一条边的贡献。这就很好办了,因为这与点的位置无关了(设黑点共m个,w为当前枚举的边的边权):
f
[
u
]
[
i
]
=
m
a
x
(
f
[
u
]
[
i
]
,
f
[
u
]
[
i
−
j
]
+
f
[
v
]
[
j
]
+
w
∗
(
j
∗
(
m
−
j
)
+
(
s
i
z
[
v
]
−
j
)
∗
(
n
−
m
−
(
s
i
z
[
v
]
−
j
)
)
)
)
f[u][i]=max(f[u][i],f[u][i-j]+f[v][j]+w*(j*(m-j)+(siz[v]-j)*(n-m-(siz[v]-j))))
f[u][i]=max(f[u][i],f[u][i−j]+f[v][j]+w∗(j∗(m−j)+(siz[v]−j)∗(n−m−(siz[v]−j))))
2.考虑枚举顺序
这道题和一般的背包dp不同的地方在于,
f
[
v
]
[
0
]
f[v][0]
f[v][0]对u有贡献,这就限制了我们不能鲁莽地从大到小或者从小到大枚举。这时候我们应该想清楚这一类背包dp的本质:这表面上看起来是一个二维的dp,实际上是一个"三维"的dp,还有一维就是当前枚举到了第几棵子树。这里相当于采取了滚动数组的思想。而这样一来,上面的转移方程成立,一个必要条件就是当我们用
f
[
u
]
[
i
−
j
]
f[u][i-j]
f[u][i−j]来更新
f
[
u
]
[
i
]
f[u][i]
f[u][i]时,
f
[
u
]
[
i
−
j
]
f[u][i-j]
f[u][i−j]仍然是没有考虑v子树的答案。所以i应该从大到小枚举。我们再来考虑j,这就十分考验我们的逻辑了。假设我们从大到小枚举,那么
f
[
u
]
[
i
]
f[u][i]
f[u][i]就会被更新,而最后
j
=
0
j=0
j=0时我们会用
f
[
u
]
[
i
]
f[u][i]
f[u][i]来更新
f
[
u
]
[
i
]
f[u][i]
f[u][i],根据刚刚我们的结论,用来更新其他状态的状态必须还没有考虑子树v,但是现在
f
[
u
]
[
i
]
f[u][i]
f[u][i]已经考虑过子树v了,这显然就错了。对于一般的背包dp,顺序之所以无关紧要,是因为
f
[
v
]
[
0
]
f[v][0]
f[v][0]一般对答案没有贡献。所以这里我们必须保证枚举
j
=
0
j=0
j=0,由于接下来的枚举与我们正在更新的
f
[
u
]
[
i
]
f[u][i]
f[u][i]没有交集,所以接下来从大到小还是从小到大枚举j就无关紧要了。
代码:
#include<bits/stdc++.h>
#define re register
#define mp make_pair
using namespace std;
const int N=2e3+5;
int n,m,a,b,c;
inline int red(){
int data=0;int w=1; char ch=0;
ch=getchar();
while(ch!='-' && (ch<'0' || ch>'9')) ch=getchar();
if(ch=='-') w=-1,ch=getchar();
while(ch>='0' && ch<='9') data=(data<<3)+(data<<1)+ch-'0',ch=getchar();
return data*w;
}
int siz[N];
vector< pair<int,int> >g[N];
inline void add(int u,int v,int w){g[u].push_back(mp(v,w));}
long long f[N][N];
void dfs(int u){
f[u][0]=f[u][1]=0;siz[u]=1;
for(int re i=g[u].size()-1;~i;--i){
int v=g[u][i].first,w=g[u][i].second;
if(siz[v])continue;dfs(v);siz[u]+=siz[v];
for(int re i=min(siz[u],m);~i;--i)
for(int re j=0;j<=min(m,siz[v])&&j<=i;++j)
f[u][i]=max(f[u][i],f[u][i-j]+f[v][j]+1ll*((m-j)*j+(siz[v]-j)*(n-siz[v]-(m-j)))*w);
}
}
int main(){
memset(f,128,sizeof(f));
n=red();m=red();m=min(m,n-m);
for(int re i=1;i^n;i++){a=red();b=red();c=red();add(a,b,c);add(b,a,c);}
dfs(1);cout<<f[1][m]<<"\n";
}