题目模型:在n个节点的树上取m个节点,要取子节点,父节点必须取,所获得的权值的最大值。
题目 :http://blog.csdn.net/woshi250hua/article/details/7644959
或http://acm.hust.edu.cn/vjudge/contest/view.action?cid=23350#overview
这一类问题,被称作树型背包问题。以前也做过很多这样的题,都是看了方程把题目敲出来,把DP题当成模拟题来做了。网上这类题的题解都是直接给的二维方程,让我一直都不理解。看了DD的 《背包九讲》中的泛化物品,更是头大。
今天下决心要把这类问题弄明白,看了好多博客,终于明白这类方程是怎么来的了。
以HDU1561http://acm.hdu.edu.cn/showproblem.php?pid=1561为例,写一下这类问题的思考过程。
首先,二维状态可以定义为 f[ u ] [ j ]: 以u为根的子树,取j个节点(u必须取)的最优解。
方程应该很好写,为从u的所有儿子转移上来的状态和。
f [ u ][ j ] = max{ f[ u1 ][ j1 ]+f[ u2 ][ j2 ]+......+f [ uk ] [ jk ] } +value[ u ] ,其中u1,u2,...uk 为节点u的所有儿子节点,j1+j2+...+jk=j-1(u本身用去一个)。
直接转移的话要枚举j1 到jk,复杂度hold不住,也不好写,怎么办?
方法:加一维状态!
定义三维状态: dp [ u ][ i ][ j ] 表示:以u为根的子树,u的前i个子树共取 j 个节点的最优解。
这个状态与01背包的状态定义非常相似。其实仔细考虑求 max{ f[ u1 ][ j1 ]+f[ u2 ][ j2 ]+......+f [ uk ] [ jk ] },就是一个类似多重背包的问题,每个儿子节点取0到j个节点,节点个数和要等于j-1。
这样,就可以给出状态转移方程:
dp [ u ][ i ][ j ] = max {dp[ u ][ i ][ j ], dp[v][son[v].size][t-1]+value[v]+dp [ u ][ i-1 ][ j-t ] }
v=son[u][i]表示u的第i个儿子,表示第i个儿子取t个节点。
这里,son[a].size()表示a的儿子节点数量,i从1到sizeson[ u ].size()。
这样,这一题就在均摊O(n^3)的复杂度解决了。
但本题按上述方法写好后交上去MLE,我们只好降低空间复杂度。
观察发现,01背包的滚动数组方法,在本题中同样使用,即可以降i这一维!
可以写出方程:dp [ u ][ j ] = max {dp[ u ][ j ],dp[v][t-1]+value[v] +dp [ u ][ j-t ] }
j需要从m开始枚举!
到此,这一题完美解决!
具体实现时,为了保证当前根节点一定被选中,可以用个小技巧。
dp[i][j]表示不用考虑根结点,f[i][j]表示考虑根结点,方程见代码
贴个AC代码:
/*
* p1561.cpp
*
* Created on: 2013-5-29
* Author: zy
*/
#include<cstdio>
#include<algorithm>
#include<vector>
#include<cstring>
using namespace std;
const int maxn = 205;
int dp[maxn][maxn],f[maxn][maxn];
int value[maxn];
vector<int> son[maxn];
int n, m;
void dfs(int u)
{
if(son[u].size()==0)
{
f[u][1]=value[u];
return ;
}
for(int i=0;i<son[u].size();i++)
{
int v=son[u][i];
dfs(v);
for(int j=m;j>=0;j--)
for(int k=0;k<=j;k++)
{
dp[u][j]=max(dp[u][j],f[v][k]+dp[u][j-k]);
}
}
for(int j=1;j<=m;j++)
f[u][j]=dp[u][j-1]+value[u];
}
int main()
{
while (scanf("%d%d", &n, &m) && n )
{
int x, y;
for (int i = 0; i <= n; i++)
son[i].clear();
for (int i = 1; i <= n; i++)
{
scanf("%d%d", &x, &value[i]);
son[x].push_back(i);
}
memset(dp, 0, sizeof(dp));
memset(f,0,sizeof(f));
dfs(0);
printf("%d\n",dp[0][m]);
}
return 0;
}