树形dp吧,状态挺显然的,dp[x][j]表示以x为根的子树中,选择了j个黑点的答案,但注意这个答案是整棵树的答案。
我们只需要对于每个儿子背包一遍,在最后更新一下dp[x][j]即可,具体可以看一眼程序。
非常重要的是,这个复杂度是n^2的,需要注意的是,如果我们要保证复杂度,for(int j=size[x];~j;j–)for(int k=size[ver[i]];~k;k–)必须要这么写,这样实际上是枚举整棵树中两两点对之间的lca,复杂度n^2就比较显然了。
差评下别的好多题解没有说复杂度也没有证明,我找了几份题解以为n^3卡了半天发现卡不掉仔细理性分析了下才发现是n^2的,get到了树形dp的正确姿势…
如果有人因为背包挂掉请注意是不是j和k都是倒着枚举的,如果因为这里挂掉请仔细想想原因(蒟蒻表示自己也因为这里wa掉了一屏…对于背包的处理不好…)最保险的做法是memcpy,但那样常数略大…
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
//by:MirrorGray
using namespace std;
const int N=2111;
ll dp[N][N],tmp[N];
int n,k,tot=-1,size[N],head[N],ver[N<<1],nxt[N<<1],len[N<<1];
void add(int x,int y,int z){
nxt[++tot]=head[x];
head[x]=tot;
ver[tot]=y;
len[tot]=z;
}
void dfs(int x,int f,int L){
size[x]=1;
for(int i=head[x];~i;i=nxt[i])if(ver[i]!=f){
dfs(ver[i],x,len[i]);
for(int j=size[x];~j;j--)for(int k=size[ver[i]];~k;k--)
dp[x][j+k]=max(dp[x][j+k],dp[x][j]+dp[ver[i]][k]);
// memcpy(tmp,dp[x],sizeof(dp[x]));
// for(int j=0;j<=size[x];j++)for(int k=0;k<=size[ver[i]];k++)
// tmp[j+k]=max(tmp[j+k],dp[x][j]+dp[ver[i]][k]);
// memcpy(dp[x],tmp,sizeof(tmp));
size[x]+=size[ver[i]];
}
for(int j=0;j<=size[x];j++)dp[x][j]+=(ll)j*(k-j)*L+(ll)(size[x]-j)*(n-size[x]-(k-j))*L;
}
int main(){
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&k);
for(int i=1;i<n;i++){
int a,b,c;scanf("%d%d%d",&a,&b,&c);
add(a,b,c);add(b,a,c);
}
dfs(1,-1,0);
printf("%lld\n",dp[1][k]);
return 0;
}