树形dp详解
一、树形dp
什么是树形dp?
树形dp是建立在树状结构的基础上的一种dp。其实个人感觉和记忆化搜索有点像。
树形dp的遍历顺序
一般的树形dp都是先找到树根,用dfs先初始化,然后递归到叶节点,然后退回来更新dp数组。
void dfs(int u,int fa){
dp[u] = ...;//初始化
for(int i = head[u];i;i = e[i].next) if(e[i].v != fa){
dfs(e[i].v,u);//先到子节点
dp[u] = ...;//转移方程
}
}
二、树形dp练习
洛谷 P1352
其实看懂题目后也不难,也就是n个节点,每个节点有一个数值,这些节点构成了一棵树,让你选一个集合,使他们的和最大,并且有子节点就不能有父节点,有父节点就不能有子节点。
设dp[u][0]为不取u节点的数值,dp[u][1]为取u节点的数值。
那么就有以下转移方程:
{
d
p
[
u
]
[
0
]
+
=
m
a
x
(
d
p
[
v
]
[
1
]
,
d
p
[
v
]
[
0
]
)
d
p
[
u
]
[
1
]
+
=
d
p
[
v
]
[
0
]
\begin{cases} dp[u][0] \ += \ max(dp[v][1],dp[v][0])\\ dp[u][1] \ += \ dp[v][0]\\ \end{cases}
{dp[u][0] += max(dp[v][1],dp[v][0])dp[u][1] += dp[v][0]
时间复杂度
O
(
n
)
O(n)
O(n)。
注意本题不是以1为根节点,需要寻找根节点,没有上司的结点即为根节点,读入时用数组标记即可。
/*
*/
#include<bits/stdc++.h>
#define rep(i,s1,s2,s3) for(i = s1;i <= s2;i += s3)
#define r(i,s1,s2,s3) for(i = s1;i >= s2;i -= s3)
#define ull unsigned long long
#define sort stable_sort
#define INF 0x7f7f7f7f
#define ll long long
using namespace std;
int n,id,r[6010],in[6010],dp[6010][2],head[6010];
struct node{
int u,v,next;
}e[6010];
void add(int u,int v){
e[++id] = node{u,v,head[u]};
head[u] = id;
}
void dfs(int u){
dp[u][0] = 0;
dp[u][1] = r[u];
int v;
for(int i = head[u];i;i = e[i].next){
v = e[i].v;
dfs(v);
dp[u][1] += dp[v][0];
dp[u][0] += max(dp[v][1],dp[v][0]);
}
}
int main(){
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
//freopen(".in","r",stdin);
//freopen(".out","w",stdout);
cin>>n;
int i,u,v;
rep(i,1,n,1) cin>>r[i];
rep(i,1,n - 1,1){
cin>>v>>u;
add(u,v);
in[v]++;
}
rep(i,1,n,1) if(!in[i]){
dfs(i);
cout<<max(dp[i][0],dp[i][1]);
break;
}
return 0;
}
洛谷 P2016
其实和上面这题差不多,但它是求最小值。
/*
*/
#include<bits/stdc++.h>
#define rep(i,s1,s2,s3) for(i = s1;i <= s2;i += s3)
#define r(i,s1,s2,s3) for(i = s1;i >= s2;i -= s3)
#define ull unsigned long long
#define sort stable_sort
#define INF 0x7f7f7f7f
#define ll long long
using namespace std;
int n,id,dp[100010][2],head[100010];
struct edge{
int u,v,next;
}e[200010];
void add(int u,int v){
e[++id] = edge{u,v,head[u]};
head[u] = id;
}
void dfs(int u,int fa){
dp[u][1] = 1;
int v;
for(int i = head[u];i;i = e[i].next) if(e[i].v != fa){
v = e[i].v;
dfs(v,u);
dp[u][0] += dp[v][1];
dp[u][1] += min(dp[v][0],dp[v][1]);
}
}
int main(){
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
//freopen(".in","r",stdin);
//freopen(".out","w",stdout);
cin>>n;
int i,l,u,v;
rep(i,1,n,1){
cin>>u>>l;
while(l--){
cin>>v;
add(u,v);
add(v,u);
}
}
dfs(0,-1);
cout<<min(dp[0][0],dp[0][1]);
return 0;
}