这题其实就是一个树上背包的裸题,但是因为刚刚接触背包,因此萌新感觉做起来很困难,所以就写了这篇题解用来加深印象。
我们可以逐段代码进行分析。
首先是主函数,这没啥好说的,其实就是输入n和k和加入边之间的关系,记得要双边加哦,然后一个初始化,一个从1开始的0为父节点的dfs和一个输出最终结果的输出,这就是主函数的全部了。为什么要初始化呢,我的理解是这样的,就是因为我们可能会遇上一些没有被转移到的转移状态,当这些转移状态没出现的时候,那么就是-1了,那我们就可以不使用这些状态进行转移。
int main(){
scanf("%lld%lld",&n,&k);
for(long long int i=1;i<=n-1;i++){
long long int x,y,z;
cin>>x>>y>>z;
v[x].push_back({y,z});
v[y].push_back({x,z});
}
memset(f,-1,sizeof f);
dfs(1,0);
cout<<f[1][k];
}
接下来看看最重要部分组成的dfs
void dfs(int x,int fa){
if(book[x])//用来标记走没走,但想想其实没啥必要,因为本来就只会走一次
return;
siz[x]=1;//初始赋值这个就只有本身这一个结点
book[x]=1;
f[x][1]=0,f[x][0]=0;//初始赋值,因为无论你这个结点是只有一个黑点还是0个黑点没法转移都是0
// cout<<x<<" "<<siz[x]<<endl;
for(int i=0;i<(int)v[x].size();i++){//遍历所有的子节点
int y=v[x][i].first;
if(y==fa)//如果遇到子节点是父节点的话就跳过
continue;
dfs(y,x);siz[x]+=siz[y];//递归找子树总结点
for(long long int j=min(k,siz[x]);j>=0;j--){
for(long long int l=0;l<=min(j,siz[y]);l++){
if(f[x][j-l]==-1) continue;//如果出现不符合的状态就跳过,没这一步只有80
//这个res其实是黑点与黑点之间还有白点与白点之间总共经过了这个边多少次的值
long long res=(l*(k-l)+(n-siz[y]-(k-l))*(siz[y]-l))*v[x][i].second;
f[x][j]=max(f[x][j],f[x][j-l]+f[y][l]+res);//更新状态
}
}
}
}
因此总代码是如下:
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<queue>
#include<vector>
#include<cstring>
using namespace std;
typedef pair<long long int,long long int>PII;
long long f[2005][2005];
vector<PII> v[2005];
bool book[2005];
long long int siz[2005];
long long int n,k;
void dfs(int x,int fa){
if(book[x])
return;
siz[x]=1;
book[x]=1;
f[x][1]=0,f[x][0]=0;
// cout<<x<<" "<<siz[x]<<endl;
for(int i=0;i<(int)v[x].size();i++){
int y=v[x][i].first;
if(y==fa)
continue;
dfs(y,x);siz[x]+=siz[y];
for(long long int j=min(k,siz[x]);j>=0;j--){
for(long long int l=0;l<=min(j,siz[y]);l++){
if(f[x][j-l]==-1) continue;
long long res=(l*(k-l)+(n-siz[y]-(k-l))*(siz[y]-l))*v[x][i].second;
f[x][j]=max(f[x][j],f[x][j-l]+f[y][l]+res);
}
}
}
}
int main(){
scanf("%lld%lld",&n,&k);
for(long long int i=1;i<=n-1;i++){
long long int x,y,z;
cin>>x>>y>>z;
v[x].push_back({y,z});
v[y].push_back({x,z});
}
memset(f,-1,sizeof f);
dfs(1,0);
// for(int i=1;i<=n;i++)
// cout<<siz[i]<<endl;
cout<<f[1][k];
}
如果有什么没讲清楚地可以在评论区下询问哦