希望在看这篇文章前你已经对基础的dp有了一定的了解,就算学的不好也能看懂这篇讲解
我们先以一道例题开头:
树形dp知名入门例题:没有上司的舞会
题目链接
树形dp就是以树为基础的dp(树就不用我说了吧,不关是数据结构还是离散数学好像都要学)
这里我们说说怎么存树;
#include<iostream>
#include<math.h>
using namespace std;
const int N=6000+10;
int last[N],ne[N],edge[N],cnt;
bool biao[N];
void add(int a, int b){
edge[cnt] = b;
ne[cnt] = last[a];
last[a] = cnt++;
}
int main()
{
int n;
cin>>n;//n名员工
int a,b;
for(int i=1;i<n;i++)//n-1条边
{cin>>a>>b;//b是a的直接上司
add(a,b);
}
}
我们模拟一下,假设有这么一棵树
假设我们需要找1的子节点
通过last[a]这个数组,我们可以得到一个cnt,通过这个cnt可以dge[cnt] = b数组得到它的一个子节点b
void add(int a, int b){
edge[cnt] = b;
ne[cnt] = last[a];
last[a] = cnt++;
}
注意这段代码cnt++是最后执行的
其实我们也可以这么写
void add(int a, int b){
cnt++;//此时cnt是从一开始,遍历时需要注意一下
edge[cnt] = b;
ne[cnt] = last[a];
last[a] = cnt;
}
我们甚至还可以这么写
void add(int a, int b){
edge[cnt] = b;
ne[cnt] = last[a];
last[a] = cnt;
cnt++;
}
我们注意在最后cnt的值服给last[a]时,last[a]的值是上一次a作为父节点时cnt的值,而我们在改变last[a]为此时cnt的值之前我们已经把它赋值给了ne[cnt] 而这个cnt又赋值给了last[a];
总结下:通过last[a]得到cnt,通过edge[cnt]可以得到一直子节点,通过ne[cnt]可以得到上一次以a作为父节点时的cnt,以此类推
学会了建树,我们可以开始看看这道题了;
还是这课树
假设我们选了1,所求的最大值就是左边子树不选2时的最大值加上右边子树不选3时的最大值最后加上1的权值;
假设我们不选1,所求最大值就是max(左边子树选2,左边子树不选2)+max(右边子树选3,右边子树不选3);
就得到了节点1的最大值,所有节点都可以根据这个方程求得该状态的最大值;
下面我们看看代码:
c++代码
#include<iostream>
#include<math.h>
#include<string.h>
using namespace std;
const int N=6000+10;
int last[N];
int ne[N],edge[N],cnt=1;
bool biao[N];
void add(int a, int b){
edge[cnt] = b;
ne[cnt] = last[a];
last[a] = cnt++;
}
int dp[N][2];
int a[N];
void dps(int root)
{
dp[root][0]=0;//不选root号节点;
dp[root][1]=a[root];//选root号节点
for(int i=last[root];i>=1;i=ne[i])
{
int j=edge[i];
dps(j);
dp[root][0]+=max(dp[j][0],dp[j][1]);
dp[root][1]+=dp[j][0];
}
}
int main()
{
int n;
cin>>n;
for(int i=1;i<=n;i++)
cin>>a[i];
//因为0包含的有边
int x,y;
for(int i=1;i<n;i++)//n-1条边
{cin>>x>>y;//y是x的直接上司
add(y,x);
biao[x]=true;//a有父节点b
}
int v=1;
while(biao[v])v++;//找到没有父节点的点
dps(v);
cout<<max(dp[v][1],dp[v][0]);
}
java代码:
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Scanner;
import java.util.Set;
public class Main{
static int N = 6010;
static int[] happy = new int[N];
static int[] h = new int[N];
static int[] e = new int[N];
static int[] ne = new int[N];
static int idx = 0;
static int[][] f = new int[N][2];
static boolean[] has_fa = new boolean[N];
static void add(int a,int b)
{
e[idx] = b;
ne[idx] = h[a];
h[a] = idx ++;
}
static void dfs(int u)
{
f[u][1] = happy[u];
for(int i = h[u]; i != -1;i = ne[i])
{
int j = e[i];
dfs(j);
f[u][1] += f[j][0];
f[u][0] += Math.max(f[j][1], f[j][0]);
}
}
public static void main(String[] args) {
Scanner scan = new Scanner(System.in);
int n = scan.nextInt();
for(int i = 1;i <= n;i ++) happy[i] = scan.nextInt();
Arrays.fill(h, -1);
for(int i = 0;i < n - 1;i ++)
{
int a = scan.nextInt();
int b = scan.nextInt();
add(b,a);
has_fa[a] = true;
}
int root = 1;
while(has_fa[root]) root ++;
dfs(root);
System.out.println(Math.max(f[root][0], f[root][1]));
}
}
同理我们看看这道题:
题目链接
#include<iostream>
#include<cstring>
using namespace std;
typedef long long LL;
const int N = 100010,M=N*2;
int h[N],e[M],ne[M],w[N],idx;
int n;
LL f[N];
void add(int a,int b)
{
e[idx]=b;ne[idx]=h[a];h[a]=idx++;
}
void dfs(int u,int father)
{
f[u]=w[u];
for(int i=h[u];~i;i=ne[i])
{
int j = e[i];
if(j!=father)//不能让它倒回去
{
dfs(j,u);
f[u]+=max(0ll,f[j]);
}
}
}
int main()
{
cin>>n;
memset(h,-1,sizeof h);
for(int i=1;i<=n;++i)cin>>w[i];
for(int i=0;i<n-1;++i){
int a,b;
cin>>a>>b;
add(a,b);
add(b,a);
}
dfs(2,-1);
LL res = f[1];
for(int i=2;i<=n;++i)res = max(res,f[i]);
cout<<res<<endl;
return 0;
}
java代码:
import java.util.*;
public class Main {
private static int n;
private static long res;
private static long[] w;
private static List<Integer>[] g;
private static long[] f ;
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
n =sc.nextInt();
w = new long[n+1];
g = new ArrayList[n+1];
f = new long[n+1];
for(int i = 0;i < n+1; i++){
g[i] = new ArrayList<Integer>();
}
for(int i = 1;i <= n;i ++){
w[i] = sc.nextLong();
}
for(int i = 0;i < n - 1;i ++)
{
int a = sc.nextInt();
int b = sc.nextInt();
g[a].add(b);
g[b].add(a);
}
dfs(1,0);
res = f[1];
for(int i = 2; i < n+1; i++){
if(res<f[i]){
res = f[i];
}
}
System.out.println(res);
}
/**
*root作为根所代表的子树有一个最大权和,将其存储在f[root]中
*/
private static void dfs(int root,int father){
f[root] = w[root];
for(int i = 0; i < g[root].size(); i++){
Integer child = g[root].get(i);
if(child == father){
continue;
}
dfs(child,root);
if(f[child] > 0){
f[root] += f[child];
}
}
}
}