树形DP(这里讲解以LOJ 10153以及LOJ 10154)
概念
顾名思义,就是在树上进行的动态规划,二叉树用左右儿子数组和邻接矩阵存储,一般的树用vector和邻接矩阵存储。
具体方法
二叉树的方法
邻接矩阵的建立(一般人都会吧)
首先要用到一个很高级的函数:scanf(或者cin)读入两个点的连边和边的权值,分别是x、y和cost,并用邻接数组——map把他们存下来——map[x][y]=z。
代码如下:
for (int i=1;i<n;i++){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
mp[x][y]=mp[y][x]=z;
}
建树的过程(树形DP中最难的部分)
要先去建立左儿子,还有结点可以放就放到右儿子处,放的时候要用左/右儿子边连上其编号。
代码也不长,在例题里有点“复杂”:
void build(int i){
for (int i=1;i<=n;i++){
if (mp[v][i]>=0){ //要先把mp初始化成-1
lson[v]=i;
a[i]=mp[v][i];
build(i);
break;
}
}
for (int i=1;i<=n;i++){
if (mp[v][i]>=0){
rson[v]=i;
a[i]=mp[v][i];
build(i);
break;
}
}
}
询问
最简单了,推出转移方程,在用一个t数组记下来求最大值就可以了。
代码:
int solve(int i,int j){
if (j<=0)return 0;
if (lson[i]==0&&rson[i]==0)return 0;
if (f[i][j]>0)return f[i][j];
for (int k=0;k<=j;k++){
int t=0;
if (k!=0)t+=转移方程1;
if (k!=j)t+=转移方程2;
f[i][j]=max(f[i][j],t);
}
return f[i][j];
}
一般树的方法
一般的树因为已经用vector存好了,所以就可以免去建树的过程,于是,就剩下了一个询问的过程,我们这次的询问和前面的完全不同,转移是有点类似于背包DP的,实际上他就是背包DP,没人告诉我这是背包,我看状态很像背包的……
于是,一大段美丽的代码出炉了:
void solve(int x){
f[x][0]=0;
for (int i=0;i<son[x].size();i++){
int y=son[x][i];
solve(y);
for (int j=m;j>=1;j--){
for (int k=j;k>=1;k--){
f[x][j]=max(f[x][j],f[x][j-k]+f[y][k]);
}
}
}
if (x!=0){
for (int i=m;i>=1;i--){
f[x][i]=f[x][i-1]+s[x];
}
}
}
例题:二叉苹果树
这道题,我们要把树枝上的苹果都压到结点上去,把这个剪线问题改成吃点问题,不仅好理解,代码也好写。
然后我们在建树的时候,可以直接把吃点的过程写上去,这样就省去了在查询时不断吃点,导致状态难找,f数组过大的情况。
所以说,在建树时,要直接把mp[i][v]赋值成-1就可以实现剪枝。
至于转移方程,那就简单了,只要把左子树的答案加上右子树的答案再加上左右子树的苹果树,就可以了。
注意:只有在左子树有的时候才能遍历左子树,同理,要在有右子树的室友才能遍历右子树
代码:
#include<bits/stdc++.h>
using namespace std;
int n,q,lson[105],rson[105],f[105][105],mp[105][105],a[105];
void build(int v){
for (int i=1;i<=n;i++){
if (mp[v][i]>=0){
lson[v]=i;
a[i]=mp[v][i];
mp[v][i]=-1;
mp[i][v]=-1;
build(i);
break;
}
}
for (int i=1;i<=n;i++){
if (mp[v][i]>=0){
rson[v]=i;
a[i]=mp[v][i];
mp[v][i]=-1;
mp[i][v]=-1;
build(i);
break;
}
}
}
int solve(int i,int j){
if (j<=0)return 0;
if (lson[i]==0&&rson[i]==0)return 0;
if (f[i][j]>0)return f[i][j];
for (int k=0;k<=j;k++){
int t=0;
if (k!=0)t+=solve(lson[i],k-1)+a[lson[i]];
if (k!=j)t+=solve(rson[i],j-k-1)+a[rson[i]];
f[i][j]=max(f[i][j],t);
}
return f[i][j];
}
int main(){
scanf("%d%d",&n,&q);
for (int i=1;i<=n;i++){
for (int j=1;j<=n;j++){
mp[i][j]=-1;
}
}
for (int i=1;i<n;i++){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
mp[x][y]=mp[y][x]=z;
}
build(1);
printf("%d\n",solve(1,q));
return 0;
}
如有错误,请发至评论区,我会仔细看并更改博客