Description
K妹的胡椒粉大卖,这辣味让食客们感到刺激,许多餐馆也买这位K妹的账。有N家餐馆,有N-1条道路,这N家餐馆能相互到达。K妹从1号餐馆开始。每一个单位时间,K妹可以在所在餐馆卖完尽量多的胡椒粉,或者移动到有道路直接相连的隔壁餐馆。第i家餐馆最多需要A[i]瓶胡椒粉。K妹有M个单位的时间,问她最多能卖多少胡椒粉。
Input
第一行有两个正整数N,M。
第二行描述餐馆对胡椒粉的最大需求量,有N个正整数,表示A[i]。
接下来有N-1行描述道路的情况,每行两个正整数u,v,描述这条道路连接的两个餐馆。
Output
一个整数,表示她最多能卖的胡椒粉瓶数。
Sample Input
样例1输入
3 5
9 2 5
1 2
1 3
样例2输入
4 5
1 1 1 2
1 2
2 3
3 4
样例3输入
5 10
1 3 5 2 4
5 2
3 1
2 3
4 2
Sample Output
样例1输出
14
样例2输出
3
样例3输出
15
Data Constraint
对于10%的数据,N≤20。
对于50%的数据,N≤110。
对于100%的数据1 ≤ N, M ≤ 500,1 ≤ A[i]≤ 10^6,
第5到第10个测试点都有多个子测试。
Hint
在样例1的中,辣妹到达城市2后就恰好没时间卖辣椒粉了。
思路
这是一道树形 dp 的题目。
我们可以设 dp[i][j][0/1]
Dp[i][j][0]表示以 i 为根的子树中,花费 j 单位时间,最
终回到 i 的最大收益。
Dp[i][j][1]表示以 i 为根的子树中,花费 j 单位时间,最
终不必回到 i 的最大收益。
转移的时候,一颗一颗子树来做。
将已做的子树信息存在 dp[i][][]里面。
枚举下一颗子树的时候,同时枚举在下一颗子树花费
的时间 k,以及在之前子树花费的时间 j。
那么转移就很显然了。
先走之前的子树回到根,再走下一颗子树不回根:
dp[i][j][0] + dp[son][k][1] –> dp[i][j+k+1][1]
先走之前的子树回到根,再走下一颗子树也回根:
dp[i][j][0] + dp[son][k][0] –> dp[i][j+k+2][0]
先下一颗子树回到根,再走之前的子树不回根:
dp[i][j][1] + dp[son][k][0] –> dp[i][j+k+2][1]
这样的时间复杂度是 O(n^2)的。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int maxn=577;
int a[maxn],list[maxn],cnt,n,m,f[maxn][maxn][2];
struct E
{
int to,next;
}e[maxn*2];
void add(int u,int v)
{
e[++cnt].to=v; e[cnt].next=list[u]; list[u]=cnt;
}
void dfs(int u,int fa)
{
for(int i=list[u]; i; i=e[i].next)
{
int v=e[i].to;
if(v==fa) continue;
dfs(v,u);
for(int j=m; j>=2; j--)
{
for(int k=0; k<=m; k++)
{
if(j-k-2<0) break;
f[u][j][0]=max(f[u][j][0],f[u][j-k-2][0]+f[v][k][1]);
}
}
for(int j=m; j>=1; j--)
{
for(int k=0; k<=m; k++)
{
if(j-k-1<0)break;
f[u][j][0]=max(f[u][j][0],f[u][j-k-1][1]+f[v][k][0]);
}
}
for(int j=m; j>=2; j--)
{
for(int k=0; k<=m; k++)
{
if(j-k-2<0) break;
f[u][j][1]=max(f[u][j][1],f[u][j-k-1-1][1]+f[v][k][1]);
}
}
}
for(int j=m; j>=1; j--)
{
f[u][j][0]=max(f[u][j][0],f[u][j-1][0]+a[u]);
f[u][j][1]=max(f[u][j][1],f[u][j-1][1]+a[u]);
}
}
int main()
{
freopen("dostavljac.in","r",stdin); freopen("dostavljac.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1; i<=n; i++) scanf("%d",&a[i]);
for(int i=1; i<=n-1; i++)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
dfs(1,0);
printf("%d",f[1][m][0]);
}