题意
给定一棵树,找出\(k\)个点,使\(\sum dis(A_i,A_{i+1})\)最小。
思路
显然选出的点构成一颗树,对于这样的树最优的排列方式使直径仅计算一次,其他边计算两次。
所以得出子状态\(f[i][j][0/1/2]\)表示以i为根的子树中选了j个点,这j个点中包含了0/1/2个直径端点的最短长度。
状态转移方程分类讨论,还好,不算毒瘤。
f[now][j+k][0]=min(f[now][j+k][0],f[now][j][0]+f[to][k][0]+edge[i].val*2);
f[now][j+k][1]=min(f[now][j+k][1],f[now][j][0]+f[to][k][1]+edge[i].val);
f[now][j+k][1]=min(f[now][j+k][1],f[now][j][1]+f[to][k][0]+edge[i].val*2);
f[now][j+k][2]=min(f[now][j+k][2],f[now][j][2]+f[to][k][0]+edge[i].val*2);
f[now][j+k][2]=min(f[now][j+k][2],f[now][j][0]+f[to][k][2]+edge[i].val*2);
f[now][j+k][2]=min(f[now][j+k][2],f[now][j][1]+f[to][k][1]+edge[i].val);
\(j\)枚举的是\(now\)的子树除了\(to\)的子树以外的节点。
代码
#include <bits/stdc++.h>
using namespace std;
namespace StandardIO {
template<typename T> inline void read (T &x) {
x=0;T f=1;char c=getchar();
for (; c<'0'||c>'9'; c=getchar()) if (c=='-') f=-1;
for (; c>='0'&&c<='9'; c=getchar()) x=x*10+c-'0';
x*=f;
}
template<typename T> inline void write (T x) {
if (x<0) putchar('-'),x=-x;
if (x>=10) write(x/10);
putchar(x%10+'0');
}
}
using namespace StandardIO;
namespace Solve {
const int N=3003;
int n,K,ans=0x3f3f3f3f;
int cnt;
int head[N];
struct node {
int to,next,val;
} edge[N<<1];
int f[N][N][3],size[N];
inline void add (int a,int b,int c) {
edge[++cnt].to=b,edge[cnt].val=c,edge[cnt].next=head[a],head[a]=cnt;
}
void dp (int now,int fa) {
size[now]=1,f[now][1][0]=f[now][1][1]=f[now][1][2]=0;
for (register int i=head[now]; i; i=edge[i].next) {
int to=edge[i].to;
if (to==fa) continue;
dp(to,now);
for (register int j=min(size[now],K); j>=1; --j) {
for (register int k=1; k<=min(size[to],K); ++k) {
if (j+k>K) continue;
f[now][j+k][0]=min(f[now][j+k][0],f[now][j][0]+f[to][k][0]+edge[i].val*2);
f[now][j+k][1]=min(f[now][j+k][1],f[now][j][0]+f[to][k][1]+edge[i].val);
f[now][j+k][1]=min(f[now][j+k][1],f[now][j][1]+f[to][k][0]+edge[i].val*2);
f[now][j+k][2]=min(f[now][j+k][2],f[now][j][2]+f[to][k][0]+edge[i].val*2);
f[now][j+k][2]=min(f[now][j+k][2],f[now][j][0]+f[to][k][2]+edge[i].val*2);
f[now][j+k][2]=min(f[now][j+k][2],f[now][j][1]+f[to][k][1]+edge[i].val);
}
}
size[now]+=size[to];
}
}
inline void MAIN () {
read(n),read(K);
for (register int i=1; i<n; ++i) {
int x,y,z;
read(x),read(y),read(z);
add(x,y,z),add(y,x,z);
}
memset(f,0x3f,sizeof(f));
dp(1,1);
for (register int i=1; i<=n; ++i) {
ans=min(ans,f[i][K][2]);
}
write(ans);
}
}
int main () {
Solve::MAIN();
}