题目描述
国家探险队长 Jack 意外弄到了一份秦始皇的藏宝图,于是,探险队一行人便踏上寻宝之旅,去寻找传说中的宝藏。
藏宝点分布在森林的各处,每个点有一个值,表示藏宝的价值。它们之间由一些小路相连,小路不会形成环,即两个藏宝点之间有且仅有一条道路。探险队从其中的一点出发,每次他们可以留一个人在此点开采宝藏,也可以不留,然后其余的人可以分成若干队向这一点相邻的点走去。需要注意的是,如果他们把队伍分成两队或两队以上,就必须留一个人在当前点,提供联络和通讯,当然这个人也可以一边开采此地的宝藏。并且,为了节约时间,队伍在前往开采宝藏过程中是不会走回头路的。现在你作为队长的助理,根据已有的藏宝图,请计算探险队所能开采的最大宝藏价值。
注意:在整个过程中,每个人最多只能开采一个点的宝藏。
输入格式
第 1 行有 2 个整数 n 和 m。其中 n 表示藏宝点的个数(1≤n≤100),m 表示探险队的人数(1≤m≤100)。
第 2 行是 n 个不超过 100 的整数,分别表示 1 到 n 每个点的宝藏价值。
接下来 n-1 行,每行两个数,x 和 y(1≤x,y≤n,x≠y),表示藏宝点 x 与 y 之间有一条路,数据保证不会有重复的路出现。
假设一开始探险队在点 1 处。
输出格式
输出一个整数,表示探险队所能获得最大宝藏价值。
样例数据 1
输入 [复制]
5 3
1 3 7 2 8
1 2
2 3
1 4
4 5
输出
16
备注
【数据范围】
对 40% 的输入数据 :1≤n≤30;m≤12。
对 100% 的输入数据 :1≤n≤100;m≤100。
解题思路:
很容易想到f[i][j]表示向子树i中派j个人所能收获的最大值:
若不分组:f[i][j]=max(f[son[i]][j])。
若分组,i点要留一人,共可以向下派i-1个人,那么按dfs遍历儿子,则:
f[i][j]=max(f[i][j],f[i][j-1-k]+f[son[i]][k]+val[i])。
但这是错误的,因为题目要求每个点只能挖一次,而我们不知道f[i][j-1-k]中是否算了val[i],所以做如下修改:
另开g[i][j]维护向子树i中派j个人,但不算val[i]所能收获的最大值,那第二个方程就变成了:
f[i][j]=max(f[i][j],g[i][j-1-k]+f[son[i]][k]+val[i])。这样就可以了。
注意修改f[i]或g[i]时要先用tmp数组来更新,避免枚举的当前子树的状态被自身的另一个状态转移了。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<set>
#define ll long long
using namespace std;
int getint()
{
int i=0,f=1;char c;
for(c=getchar();(c<'0'||c>'9')&&c!='-';c=getchar());
if(c=='-')f=-1,c=getchar();
for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
return i*f;
}
const int N=105;
int n,m,val[N],f[N][N],g[N][N],tmp[N];
int tot,first[N],nxt[N<<1],to[N<<1];
void add(int x,int y)
{
nxt[++tot]=first[x],first[x]=tot,to[tot]=y;
}
void dfs(int u,int fa)
{
int i,j;
f[u][1]=val[u];
for(int e=first[u];e;e=nxt[e])
{
int v=to[e];
if(v!=fa)
{
dfs(v,u);
for(i=1;i<=m;i++)tmp[i]=max(f[v][i],f[u][i]);
for(i=1;i<=m;i++)
for(j=1;j<i;j++)
tmp[i]=max(tmp[i],g[u][i-1-j]+f[v][j]+val[u]);
for(i=1;i<=m;i++)f[u][i]=tmp[i];
for(i=1;i<=m;i++)tmp[i]=max(f[v][i],g[u][i]);
for(i=1;i<=m;i++)
for(int j=1;j<i;j++)
tmp[i]=max(tmp[i],g[u][i-j]+f[v][j]);
for(i=1;i<=m;i++)g[u][i]=tmp[i];
}
}
}
int main()
{
//freopen("lx.in","r",stdin);
//freopen("lx.out","w",stdout);
int x,y;
n=getint(),m=getint();
for(int i=1;i<=n;i++)val[i]=getint();
for(int i=1;i<n;i++)
{
x=getint(),y=getint();
add(x,y),add(y,x);
}
dfs(1,0);
cout<<f[1][m];
return 0;
}