题目描述
有一棵点数为 n 的树,树边有边权。给你一个在 0∼n 之内的正整数 k ,你要在这棵树中选择 kk 个点,将其染成黑色,并将其他 的 n−k 个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间的距离的和的受益。问受益最大值是多少。
输入格式
第一行包含两个整数 n,k。
第二到 nn 行每行三个正整数 fr, to, dis,表示该树中存在一条长度为 disdis 的边 (fr, to)。输入保证所有点之间是联通的。
输出格式
输出一个正整数,表示收益的最大值。
思路:
题面很简单,没有什么好说的。一个比较简单的思路是枚举k个点,计算最大的收益,但是这样非常容易T,如果n,k只有20的话也许可以考虑状压dp(?),但这里显然不行。点不行,我们就考虑边,计算每一条边可能对整体产生的贡献,假设这条边左侧有a个黑点,b个白点,右侧有c个黑点,d个白点,这条边的权值为w,那么这条边对整体来说产生的贡献就是
重新审视整个题目,于是想到树形dp,设dp【x】【j】是以x为根的子树,在有j个黑点的情况下,上面每一条边对整体产生的贡献(重点,整体而不是在子树上产生的贡献)。不妨设y是x的直接子节点,那么就得到如下转移方程
其中val是边<x,y>产生的贡献,计算方式为
w为这条边的权值,p是决策,即x树下面有p个黑点时的贡献,这条边右边的黑点就是p,左边的黑点是k-p,右边的白点是右边所有的点的个数减去黑点个数,即size【y】-p,左边白点数计算比较复杂,是所有的白点数(n-k)减去右边的白点数,结果是n-k-size【y】+p
dp三要素,阶段,状态,决策。用这三个去写dp转移方程。阶段就是到达子节点的所有边,这样我们就能知道x作为根节点的树的答案(在这里是对整体的贡献),状态在这里是根节点及黑点数,那么决策很明显就是子树的黑点数,实际上我们想出的方程应该再多一个维度,也就是阶段,但这一维可以被省略掉。在计算时,我们希望以如下方式计算的,实际上,这也确实帮助我们算出了val,对于dp【x】【k】,更准确的说法应该是dp【i】【x】【k】,也就是说,我们实际的方程应该是
out[y]的含义是y的出度,也就是y连向y的子节点的边全部处理完了的状态
dp的真正含义是到第i条边为止,一共选择了j个黑点对全局产生的贡献。这样子看,不就和分组背包一模一样了吗(有若干组物品,总共有V的容积,最大化装下物品的价值)。无非是多出了根节点作为一个维度,而这个维度我们完全可以放到无向图里去考虑,这个维度其实是为了遍历所有的边而存在的(当然不同的树形dp意义不同)
当然省略掉这一维后相应的j就要倒序循环,否则j-p转移就会出问题,对于i正序倒序无所谓,随便什么顺序都可以,对于p循环,一般顺序也没有关系,但是将实际方程与省略后的方程对比,不难发现p=0是一个特殊的情况,我们在由前一个阶段过来的时候,即使dp[x][j]的j没有变化,也得注意加上y全白的贡献。一个方法是正序p循环,另一个方法是预先处理掉这种情况然后逆序。
我们需要注意初值,因为我们需要保证后面两个状态是合法的,这样我们在整体赋初值时是置为-1,而对于j=0和1 的情况无论怎样都是合法的,但是贡献不知道,不妨置为0,这样只有j=1时有可能会更新答案,这与我们的预期一致
附代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
int head[2010],nxt[4010],ver[4010];
ll w[4010];
int n,k,tot=0;
ll dp[2010][2010];
int size[2010];
int v[2010];
int ans=-1;
const ll INF=1<<30;
const ll mod=1e9+7;
void add(int u,int v,ll z)
{
ver[++tot]=v;
nxt[tot]=head[u];
head[u]=tot;
w[tot]=z;
}
void dfs(int x)
{
v[x]=1;
int sum=1,t=1;
size[x]=1;
dp[x][0]=dp[x][1]=0;
for(int i=head[x];i!=-1;i=nxt[i])
{
int y=ver[i];
if(v[y])continue;
dfs(y);
size[x]+=size[y];
for(int j=min(k,size[x]);j>=0;j--)
{
if(dp[x][j]!=-1)dp[x][j]+=dp[y][0]+(ll)size[y]*(n-k-size[y])*w[i];
for(int p=min(j,size[y]);p;p--)
{
if(dp[x][j-p]==-1)continue;
ll val=(ll)(p*(k-p)+(size[y]-p)*(n-k-size[y]+p))*w[i];
dp[x][j]=max(dp[x][j],dp[x][j-p]+dp[y][p]+val);
}
}
}
}
int main()
{
scanf("%d%d",&n,&k);
if(n-k<k)k=n-k;
//memset(dp,~0x3f,sizeof(dp));
memset(head,-1,sizeof(head));
memset(dp,-1,sizeof(dp));
for(int i=1;i<n;i++)
{
int u,v;
ll w;
scanf("%d%d%lld",&u,&v,&w);
add(u,v,w);
add(v,u,w);
}
dfs(1);
printf("%lld\n",dp[1][k]);
return 0;
}