You're given a tree with weights of each node, you need to find the maximum subtree of specified size of this tree.
Tree Definition
A tree is a connected graph which contains no cycles.
Input
There are several test cases in the input.
The first line of each case are two integers N(1 <= N <= 100), K(1 <= K <= N), where N is the number of nodes of this tree, and K is the subtree's size, followed by a line with N nonnegative integers, where the k-th integer indicates the weight of k-th node. The following N - 1 lines describe the tree, each line are two integers which means there is an edge between these two nodes. All indices above are zero-base and it is guaranteed that the description of the tree is correct.
Output
One line with a single integer for each case, which is the total weights of the maximum subtree.
Sample Input
3 1 10 20 30 0 1 0 2 3 2 10 20 30 0 1 0 2
Sample Output
30 40
题意:
从一颗 n 个节点的树中选出含有 k 个节点的子树,求子树最大权值
分析:
树形背包dp
dp[ fa ][ x ] 表示以 fa 为根节点,选 x 个节点的最大权值
则状态转移方程为:dp[ fa ][ fax ] = max(dp[ fa ][ fax ], dp[ fa ][ fax - sonx ] + dp[ son ][ sonx ])( 1 < sonx < fax )
对于 fa 的每个孩子,可以选择 1,2,3...fax - 1个节点分配给它
代码:
#include <cstdio>
#include <cstring>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
const int MAXN = 100 + 5;
int n, k;
vector<int> adj[MAXN]; //邻接表,adj[x] 中保存所有与 x 相连的节点
int dp[MAXN][MAXN], vis[MAXN], weight[MAXN];
void dfs(int fa)
{
dp[fa][1] = weight[fa]; //以fa为根,只有一个节点的话结果就是它本身
vis[fa] = 1; //标记为已经访问
int len = adj[fa].size(); //得到与它相连的节点的个数
for (int i = 0; i < len; i++)
{
int son = adj[fa][i]; //得到子节点编号
if (!vis[son]) //如果还未访问
{
dfs(son); //先递归子节点求值
for (int fax = k; fax > 0; fax--) //状态转移方程
for (int sonx = 1; sonx < fax; sonx++)
dp[fa][fax] = max(dp[fa][fax], dp[fa][fax - sonx] + dp[son][sonx]);
}
}
}
int main()
{
while (~scanf("%d%d", &n, &k))
{
for (int i = 0; i < n; i++)
{
scanf("%d", &weight[i]);
adj[i].clear(); //先将邻接表清空
}
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
adj[x].push_back(y); //加入各自的表中
adj[y].push_back(x);
}
memset(dp, 0, sizeof(dp));
memset(vis, 0, sizeof(vis));
dfs(0); //从节点 0 开始深搜
int ans = 0;
for (int i = 0; i < n; i++)
ans = max(ans, dp[i][k]);
printf("%d\n", ans);
}
return 0;
}