树形依赖背包指的就是一类具有树形依赖关系的背包问题。当选一个物品的前提是选另一件物品,而这些依赖关系构成了一个树形关系。在容量有限的情况下,然后求最大的价值,这类问题我们就称之为树形依赖背包。
树形依赖背包问题实际上是一类分组背包问题,我们可以将每个点的子树看成一个组,因为子树内会选择一定的点,但是选择的点数只有一种情况,所以我们可以将子树选择 i i i( i ∈ [ 0 , s i z s o n x ] i∈[0,siz_{sonx}] i∈[0,sizsonx])的情况看成一类物品,然后就做分组背包即可。
我们设 f [ i ] [ j ] f[i][j] f[i][j]表示点 i i i,字数里选了 j j j容量的最大价值,那么则有DP方程:
for(int i=k;i>=w[x];i--)
for(int j=0;j<=k;j++){
if(i+j>k) break;
f[x][i+j]=max(f[x][i+j],f[x][i]+f[son][j])
}
一般的依赖关系会要求选择当前子树的树根,因此初始状态是f[x][w[x]]=v[x]
,并且在第一维枚举容量时只枚举到w[x]
。
还有一类DP方程, f [ i ] [ j ] f[i][j] f[i][j]表示点 i i i字数内选择了 j j j个点所获得的最大的收益,那么方程式也稍微变一下即可:
for(int i=siz[x];i>=1;i--)
for(int j=0;j<=siz[son];j++){
if(i+j>k) break;
f[x][i+j]=max(f[x][i+j],f[x][i]+f[son][j])
}
这类方程有一个很实用的优化,就是我们刚开始初始化siz[x]=1
,然后每次枚举完一个儿子的时候,就执行语句siz[x]+=siz[son]
。这样可以大大降低复杂度,有巨佬指出,原来不进行优化的时间复杂度为 O ( n 3 ) O(n^3) O(n3),但是进行了这个 s i z siz siz数组的优化后,可以接近于 O ( n 2 ) O(n^2) O(n2)这也就是为什么许多 n n n为 2500 2500 2500的题目,也可以用树形依赖背包来做。
例题:
洛谷:P2014选课
题目描述
在大学里每个学生,为了达到一定的学分,必须从很多课程里选择一些课程来学习,在课程里有些课程必须在某些课程之前学习,如高等数学总是在其它课程之前学习。现在有N门功课,每门课有个学分,每门课有一门或没有直接先修课(若课程a是课程b的先修课即只有学完了课程a,才能学习课程b)。一个学生要从这些课程里选择M门课程学习,问他能获得的最大学分是多少?
输入输出格式
输入格式:
第一行有两个整数N,M用空格隔开。(1<=N<=300,1<=M<=300)
接下来的N行,第I+1行包含两个整数ki和si, ki表示第I门课的直接先修课,si表示第I门课的学分。若ki=0表示没有直接先修课(1<=ki<=N, 1<=si<=20)。
输出格式:
只有一行,选M门课程的最大得分。
输入输出样例
输入样例#1:
7 4
2 2
0 1
0 4
2 1
7 1
7 6
2 2
输出样例#1:
13
这道题就是最基本的树形依赖背包模板,就是上面说的第二类DP方程,甚至由于N的范围很小,直接 O ( n 3 ) O(n^3) O(n3)的DP也不会超时
代码中的方程式与上文所述的方程式效果相同,但是一般会选用上文的那种,因为那种方程式配合 s i z siz siz优化的时候复杂度更优
#include<bits/stdc++.h>
#define MAXN 205
using namespace std;
int read(){
char c;int x;while(c=getchar(),c<'0'||c>'9');x=c-'0';
while(c=getchar(),c>='0'&&c<='9') x=x*10+c-'0';return x;
}
int n,m,cnt,siz[MAXN],head[MAXN],nxt[MAXN],f[MAXN][105];
struct node{
int to,val;
}L[MAXN];
void add(int x,int y,int c){
L[cnt]=(node){y,c};
nxt[cnt]=head[x];head[x]=cnt;cnt++;
}
void get(int x){
siz[x]=1;
for(int i=head[x];i!=-1;i=nxt[i]){
int to=L[i].to;
get(to);siz[x]+=siz[to];
}
}
int dfs(int x){
for(int i=head[x];i!=-1;i=nxt[i]){
int to=L[i].to;dfs(to);
for(int j=siz[x]-1;j>=0;j--)
for(int k=0;k<=siz[to]-1;k++)
if(j>k) f[x][j]=max(f[x][j],f[x][j-k-1]+L[i].val+f[to][k]);
}
}
int main()
{
n=read();m=read();
memset(head,-1,sizeof(head));
memset(f,~0x3f,sizeof(f));
for(int i=1;i<=n;i++){
int x=read()+1,c=read();
add(x,i+1,c);
}
for(int i=1;i<=n+1;i++) f[i][0]=0;
get(1);dfs(1);
printf("%d",f[1][m]);
return 0;
}