总结
树形dp主要是先建立一个树图, 对于每个结点定义一个状态,进行转移,通常用递归,通过类似深度优先遍历的方法,使得在回溯时更新全部结点
例题
1. 没有上司的舞会(简单dp)洛谷P1352
题意: 有n个人,每个人之间有从属关系,每个人不想与上司之间同时参加舞会,每个人有一个快乐指数,问能参加舞会最大的快乐指数?
解:定义
d
p
[
u
]
[
0
]
、
d
p
[
u
]
[
1
]
dp[u][0]、dp[u][1]
dp[u][0]、dp[u][1]分别为 所有以u为根的子树方案中不选u/选u的方案的最大值。通过递归回溯最先遍历到叶子结点,向前回溯,逐步更新前面结点。
选该结点时子节点不能够选, 不选时取max
状态转移方程
d
p
[
u
]
[
0
]
=
∑
m
a
x
(
d
p
[
s
1
]
[
0
]
,
d
p
[
s
1
]
[
1
]
)
dp[u][0]=\sum max(dp[s_1][0],dp[s_1][1])
dp[u][0]=∑max(dp[s1][0],dp[s1][1])
d
p
[
u
]
[
1
]
=
∑
d
p
[
s
1
]
[
0
]
dp[u][1] = \sum dp[s_1][0]
dp[u][1]=∑dp[s1][0]
#include <bits/stdc++.h>
using namespace std;
const int N = 6e3 + 5;
int n, happy[N];
int h[N], e[N], ne[N], idx; //建树
int dp[N][2];
bool has_father[N];
void add(int a, int b) //添加树的边
{
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void dfs(int u)
{
dp[u][1] = happy[u]; //先初始化取该结点是快乐度
for (int i = h[u]; i != -1; i = ne[i]){ //遍历子树
int j = e[i];
dfs(j); //深度优先遍历
dp[u][0] += max(dp[j][0], dp[j][1]); //回溯时更新
dp[u][1] += dp[j][0];
}
}
int main()
{
cin >> n;
for (int i = 1; i <= n; i++) cin >> happy[i];
memset(h, -1, sizeof(h));
for (int i = 1; i < n; i++){
int l, k;
cin >> l >> k;
has_father[l] = true;
add(k, l);
}
int root = 1;
while(has_father[root]) root++; //找到根节点
dfs(root);
cout << max(dp[root][0], dp[root][1]) << endl;
return 0;
}
2.选课(树上背包)洛谷P2014
题意:有一个树形结构的选课图,只有选了课的先修课才能选该课,每门课有固定学分,问在给定的选课数下能选到的最多学分?
思想:
1.由于有的课有先修课程,有的课无,可看作一个森林,全部课都可看成有一个0结点的先修课,则需求m+1门课,从0开始dfs
2.从非叶子结点开始枚举,可将一个结点的子树看成一个分组,进行分组背包,由于每个结点只对应其下的状态信息,如果要使用该结点的dp信息,就需要使用该结点,因此可对应分组背包做法
3.定义状态
d
p
[
u
]
[
j
]
dp[u][j]
dp[u][j]表示以u为根节点的子树,并且底下有最多j个结点(不包括u点)的可获得的最大值。每次使用1个结点可看作是费用,结点上权值可看作价值。
4.状态转移方程同分组背包,第一步枚举分组,在枚举体积(倒序),最后枚举背包内的内容。
状态转移方程:
d
p
[
u
]
[
j
]
=
m
a
x
(
d
p
[
u
]
[
j
]
,
d
p
[
u
]
[
j
−
k
−
1
]
+
d
p
[
e
[
i
]
]
[
k
]
+
w
[
i
]
)
dp[u][j] = max(dp[u][j], dp[u][j - k - 1] + dp[e[i]][k]+w[i])
dp[u][j]=max(dp[u][j],dp[u][j−k−1]+dp[e[i]][k]+w[i])
在不选该分组、选该分组之间取max
选该分组一定要选该点,从底下1个结点枚举到底下m个结点,更新最大值,更新完即可
#include <bits/stdc++.h>
using namespace std;
const int N = 305;
int n, m;
int h[N], w[N], e[N], ne[N], idx;
int dp[N][N]; //以u为根节点结点,选j个点
void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++;
}
void dfs(int u)
{
for (int i = h[u]; ~i; i = ne[i]){
dfs(e[i]);
for (int j = m; j >= 1; j --)
for (int k = 0; k < j; k ++)
dp[u][j] = max(dp[u][j], dp[u][j - k - 1] + dp[e[i]][k] + w[i]);
}
}
int main()
{
memset(h, -1, sizeof(h));
cin >> n >> m;
for (int i = 1; i <= n; i ++){
int a, b;
cin >> a >> b;
add(a, i, b);
}
dfs(0);
cout << dp[0][m];
return 0;
}