树形DP简单总结一下就是下面几个步骤
- 前向星建树
- 定状态
- 确定状态转移方程
- 初始化
- DFS 进行状态转移,这里要注意是 从上向下递推还是从下往上递推
- ps:这里主要是挑了几个自认为比较典型的例题吧,难度不大,重点是熟悉一下树形DP的思想与步骤,虽然一些题目的状态转移方程比较好想,但千万不要眼高手低,最好亲手敲代码实现一下,要特别注意边界问题,我因为边界和初始化的问题,这里面也有的题目 WA 了好几发(还是菜)
典型例题
HDU 4118 Holiday’s Accommodation
有一棵树(n≤1e5),树上每个节点都是一座房子,房子里有人, 树的每条边都有一个权值(距离),所有人都要去别人家玩(通过最短路径),有两个要求:
- 所有人都要去别人的房子(不能不动)
- 不能同时有两个人去同一个房子
求这些人移动距离总和的最大值
- 一道树形DP的入门题,主要学习一下 单条边对于全局的贡献 这一思想
摘自(SDU程序设计思维与实践 Week13-动态规划(四) P29)
#pragma comment(linker,"/STACK:102400000,102400000")
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=1e5+10;
typedef long long ll;
ll ans;
int t,n,x,y,z,tot,head[maxn],cnt[maxn];
struct Edge{
int to,next,w;
}e[2*maxn];
void init(int nn){
ans=0;tot=0;
for(int i=0;i<=nn;i++){
cnt[i]=0;
head[i]=-1;
}
}
void add(int a,int b,int c){
e[++tot].to=b;
e[tot].w=c;
e[tot].next=head[a];
head[a]=tot;
}
void solve(int x,int fa){
cnt[x]++;
for(int i=head[x];i!=-1;i=e[i].next){
int y=e[i].to,w=e[i].w;
if(y==fa) continue;
solve(y,x);
cnt[x]+=cnt[y];
ans+=(ll)2*w*min(cnt[y],n-cnt[y]);
}
}
int main()
{
scanf("%d",&t);
for(int ccase=1;ccase<=t;ccase++){
scanf("%d",&n);
init(n);
for(int i=1;i<n;i++){
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
solve(1,-1);
printf("Case #%d: %lld\n",ccase,ans);
}
return 0;
}
- ps:这道题会爆栈,所以要么改递推要么扩容
洛谷P1352 没有上司的舞会
公司共有n(n ≤ 6000)位员工。公司要举行一个舞会。为了让到会的每个人不受他的直接上司约束而能玩得开心,公司领导决定:
- 如果邀请了某个人,那么一定不会再邀请他的直接的上司,但该人的上司的上司,上司的上司的上司……都可以邀请。
- 已知每个人最多有唯一的一个上司。
- 公司的每个人参加晚会都能为晚会增添一些气氛,求一个邀请方案, 使气氛值的和最大。
- ps: 有人的气氛值为负数
思路
- d p [ x ] [ 0 ] dp[x][0] dp[x][0] 表示以 x x x 为根的子树且未选择 x x x 号节点时的最大值
- d p [ x ] [ 1 ] dp[x][1] dp[x][1] 表示以 x x x 为根的子树且选择 x x x 号节点时的最大值
- 状态转移方程:
- d p [ x ] [ 0 ] + = m a x ( d p [ y ] [ 0 ] , d p [ y ] [ 1 ] ) dp[x][0]+=max(dp[y][0],dp[y][1]) dp[x][0]+=max(dp[y][0],dp[y][1])(不选 x x x 则选不选 y y y 都可)
- d p [ x ] [ 1 ] + = m a x ( 0 , d p [ y ] [ 0 ] ) dp[x][1]+=max(0,dp[y][0]) dp[x][1]+=max(0,dp[y][0]) (选 x x x 则肯定不能选 y y y)
- y y y 是节点 x x x 能直接到达的所有的子节点
- 最终答案就是 m a x ( d p [ r o o t ] [ 0 ] , d p [ r o o t ] [ 1 ] ) max(dp[root][0],dp[root][1]) max(dp[root][0],dp[root][1])
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=6010;
int n,ans,tot,root,head[maxn],dp[maxn][2];
bool father[maxn];
struct edge{
int to,next;
}e[2*maxn];
void init(){
ans=tot=0;
memset(dp,0,sizeof(dp));
memset(father,0,sizeof(father));
memset(head,-1,sizeof(head));
}
void add(int x,int y){
e[++tot].to=y;
e[tot].next=head[x];
head[x]=tot;
}
void solve(int x){
for(int i=head[x];i!=-1;i=e[i].next){
int y=e[i].to;
solve(y);
dp[x][0]+=max(dp[y][0],dp[y][1]);
dp[x][1]+=max(0,dp[y][0]);
}
}
int main()
{
scanf("%d",&n);
init();
for(int i=1;i<=n;i++){
scanf("%d",&dp[i][1]);
}
int a,b;
for(int i=1;i<n;i++){
scanf("%d%d",&a,&b);
add(b,a);
father[a]=1;
}
for(int i=1;i<=n;i++){
if(!father[i]){
root=i;
break;
}
}
solve(root);
ans=max(dp[root][0],dp[root][1]);
printf("%d\n",ans);
return 0;
}
UVA1292 Strategic game
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=1510;
int n,tot,head[maxn],dp[maxn][2];
struct edge{
int to,next;
}e[2*maxn];
void init(int nn){
tot=0;
for(int i=0;i<=nn;i++){
head[i]=-1;
dp[i][0]=0;dp[i][1]=1;
}
}
void add(int x,int y){
e[++tot].to=y;
e[tot].next=head[x];
head[x]=tot;
}
void solve(int x,int f){
for(int i=head[x];i!=-1;i=e[i].next){
int y=e[i].to;
if(y==f) continue;
solve(y,x);
dp[x][0]+=dp[y][1];
dp[x][1]+=min(dp[y][0],dp[y][1]);
}
}
int main()
{
while(~scanf("%d",&n)){
init(n);
int a,b,c;
for(int i=0;i<n;i++){
scanf("%d:(%d)",&a,&b);
while(b--){
scanf("%d",&c);
add(a,c);
add(c,a);
}
}
solve(0,-1);
printf("%d\n",min(dp[0][0],dp[0][1]));
}
return 0;
}
洛谷P1122 最大子树和
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=16010;
int n,tot,ans,head[maxn],w[maxn],dp[maxn];
struct edge{
int to,next;
}e[2*maxn];
void init(){
ans=0;
tot=0;
memset(head,-1,sizeof(head));
memset(dp,0,sizeof(dp));
}
void add(int x,int y){
e[++tot].to=y;
e[tot].next=head[x];
head[x]=tot;
}
void solve(int x,int fa){
dp[x]=w[x];
for(int i=head[x];i!=-1;i=e[i].next){
int y=e[i].to;
if(y==fa) continue;
solve(y,x);
dp[x]+=max(0,dp[y]);
}
ans=max(ans,dp[x]);
}
int main()
{
scanf("%d",&n);
init();
for(int i=1;i<=n;i++)
scanf("%d",&w[i]);
int a,b;
for(int i=1;i<=n-1;i++){
scanf("%d%d",&a,&b);
add(a,b);
add(b,a);
}
solve(1,-1);
printf("%d\n",ans);
return 0;
}
树的直径
HDU 2196 Computer
实验室里原先有一台电脑(编号为1),最近实验室又购置了 N-1 台电脑,编号为 2 到 N 。每台电脑都用网线连接到一台先前安装的电脑上,求第 i 台电脑到其他电脑的最大网线长度。
解法一
根据输入数据建树,假设树的最长路的两个叶子结点为 v 1 , v 2 v1,v2 v1,v2,这道题要求找到某个结点 x x x 所能到达的最长路径,那么这个结点 x x x 的最长路径要么是到 v 1 v1 v1的路径,要么就是到 v 2 v2 v2的路径,所以首先需要从任意结点开始执行 DFS 找到最远结点 v 1 v1 v1,然后再以 v 1 v1 v1 为源点执行 DFS 找到另一个最远结点 v 2 v2 v2 同时记录 v 1 v1 v1 到各点的最长路径 d 1 [ N ] d1[N] d1[N],最后再以 v 2 v2 v2 为源点执行 DFS,同时记录 v 2 v2 v2 到各点的最长路径 d 2 [ N ] d2[N] d2[N],则某个结点 x x x 所能到达的最长路径即为 m a x ( d 1 [ x ] , d 2 [ x ] ) max(d1[x],d2[x]) max(d1[x],d2[x])。
代码实现
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cstring>
using namespace std;
const int maxn=100010;
struct Edge{
int u,v,w,next;
}Edges[maxn];
int head[maxn],d1[maxn],d2[maxn],tot,vv;
bool vis[maxn];
void init(int n){
tot=0;
for(int i=0;i<n;i++)
head[i]=-1;
}
void addEdge(int u,int v,int w){
Edges[tot].u=u;
Edges[tot].v=v;
Edges[tot].w=w;
Edges[tot].next=head[u];
head[u]=tot;
tot++;
}
void dfs(int u,int path,int *d){
d[u]=path;
if(d[vv]<d[u])
vv=u;
for(int i=head[u];i!=-1;i=Edges[i].next){
if(!vis[Edges[i].v]){
vis[Edges[i].v]=true;
dfs(Edges[i].v, path+Edges[i].w, d);
}
}
}
int main()
{
int n,u,v,w;
while(~scanf("%d",&n)){
init(n+1);
for(int i=2;i<=n;i++){
scanf("%d%d",&v,&w);
addEdge(i,v,w);
addEdge(v,i,w);
}
//找到最远点v1
memset(vis,0,sizeof(vis));
d1[vv]=0;
vis[1]=1;
dfs(1,0,d1);
//找到最远点v2
memset(vis,0,sizeof(vis));
d1[vv]=0;
vis[vv]=1;
dfs(vv,0,d1);
//v2执行dfs
memset(vis,0,sizeof(vis));
d2[vv]=0;
vis[vv]=1;
dfs(vv,0,d2);
for(int i=1;i<=n;i++){
int maxd=max(d1[i],d2[i]);
printf("%d\n",maxd);
}
}
}
解法二:树形DP
先将无根树转为有根树,对于每个点来说它所能达到的最远距离,就是以自身为根结点向下 DFS 的最大距离,或者通过自身父节点,再加上父节点不经过自身所能达到的最远距离,这时又有两种情况:
- 父节点最远距离是以它为根结点向下产生(如果最远距离经过这个子节点需要考虑父节点向下的第二远的距离)
- 父节点也是通过它的父节点产生的最远距离
比如下图,对于 2 来说,它的最远距离可能是以它为根的子树(蓝色部分)产生的最大值,或者,通过它父节点不经过本身结点所能达到的最远点(也就是整棵树除蓝色部分,相当于红色部分)。
对于边
(
u
,
v
)
(u,v)
(u,v)
- d p [ u ] [ 0 ] dp[u][0] dp[u][0] 表示以 u u u 为根形成的子树, u u u 所能到达的最远距离
- d p [ u ] [ 1 ] dp[u][1] dp[u][1] 表示以 u u u 为根形成的子树, u u u 所能到达的次远距离
- d p [ u ] [ 2 ] dp[u][2] dp[u][2] 表示通过 u u u 的父节点不经过自身结点,所能达到的最远距离
- 对于 d p [ u ] [ 0 ] dp[u][0] dp[u][0] 和 d p [ u ] [ 1 ] dp[u][1] dp[u][1] 可以通过一遍 dfs 一次得到。
- 对于 d p [ u ] [ 2 ] dp[u][2] dp[u][2] 可以通过已知父节点的最优情况,来推出子节点的最优情况,还是从根向子节点遍历一遍求出答案。
- 第一次从下到上,转移方程为
if(dp[v][0]+e[i].w >= dp[u][0]){
dp[u][1]=dp[u][0];
dp[u][0] = dp[v][0]+e[i].w;
}
else if(dp[v][0]+e[i].w > dp[u][1])
dp[u][1] = dp[v][0]+e[i].w;
要先算出 d p [ v ] [ 0 ] dp[v][0] dp[v][0] 才能知道 d p [ u ] [ 0 ] dp[u][0] dp[u][0] ,所以为从下往上。
- 第二次从上往下,转移方程为
if(dp[u][0] == dp[v][0] + e[i].w)
dp[v][2] = max(dp[u][2],dp[u][1]) + e[i].w;
else
dp[v][2] = max(dp[u][2],dp[u][0]) + e[i].w;
要算 d p [ v ] [ 2 ] dp[v][2] dp[v][2] ,要先算 d p [ u ] [ 0 ] dp[u][0] dp[u][0] ,所以为从上往下。
最后的答案为 m a x ( d p [ u ] [ 0 ] , d p [ u ] [ 2 ] ) max(dp[u][0],dp[u][2]) max(dp[u][0],dp[u][2])
参考 :
https://www.cnblogs.com/dongdong25800/p/11056413.html
https://blog.csdn.net/angon823/article/details/52261423
代码实现
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=1e4+10;
int n,tot,head[maxn],dp[maxn][3];
struct edge{
int to,next,w;
}e[2*maxn];
void init(int nn){
tot=0;
for(int i=0;i<=nn;i++){
head[i]=-1;
dp[i][0]=dp[i][1]=dp[i][2]=0;
}
}
void add(int x,int y,int z){
e[++tot].to=y;
e[tot].next=head[x];
e[tot].w=z;
head[x]=tot;
}
void dfs1(int x,int f){
for(int i=head[x];i!=-1;i=e[i].next){
int y=e[i].to,w=e[i].w;
if(y==f) continue;
dfs1(y,x);
int dis=dp[y][0]+w;
if(dp[x][0]<dis){
dp[x][1]=dp[x][0];
dp[x][0]=dis;
}
else if(dp[x][1]<dis)
dp[x][1]=dis;
}
}
void dfs2(int x,int f){
for(int i=head[x];i!=-1;i=e[i].next){
int y=e[i].to,w=e[i].w;
if(y==f) continue;
int dis=dp[y][0]+w;
if(dp[x][0]==dis)
dp[y][2]=max(dp[x][2],dp[x][1])+w;
else
dp[y][2]=max(dp[x][2],dp[x][0])+w;
dfs2(y,x);
}
}
int main()
{
while(~scanf("%d",&n)){
init(n);
int a,b;
for(int i=2;i<=n;i++){
scanf("%d%d",&a,&b);
add(i,a,b);
add(a,i,b);
}
dfs1(1,-1);
dfs2(1,-1);
for(int i=1;i<=n;i++){
printf("%d\n",max(dp[i][0],dp[i][2]));
}
}
return 0;
}
POJ3310 Caterpillar
树的直径的一个应用,
首先判断是否为全连通的无环图,然后标记在直径上的点,最后判断不在直径上的点与直径的最短距离是否为 1
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=110;
int n,m,tot,maxd,vv,head[maxn],fa[maxn],d[maxn],judge[maxn];
struct edge{
int to,next;
}e[6*maxn];
void init(int nn){
tot=0;
for(int i=0;i<=nn;i++){
head[i]=-1;
fa[i]=i;
d[i]=judge[i]=0;
}
}
int find(int x){
return fa[x]!=x?fa[x]=find(fa[x]):x;
}
bool join(int x,int y){
int fx=find(x),fy=find(y);
if(fx!=fy){
fa[fy]=fx;
return true;
}
return false;
}
void add(int x,int y){
e[++tot].to=y;
e[tot].next=head[x];
head[x]=tot;
}
void dfs(int x,int f){
for(int i=head[x];i!=-1;i=e[i].next){
int y=e[i].to;
if(y==f) continue;
d[y]=d[x]+1;
judge[y]=d[y];
if(d[y]>maxd){
maxd=d[y];vv=y;
}
dfs(y,x);
judge[x]=max(judge[x],judge[y]);
}
}
int main()
{
int cnt=0;
while(~scanf("%d",&n)){
cnt++;
if(n==0) break;
init(n);
scanf("%d",&m);
int a,b;
bool flag=true;
for(int i=1;i<=m;i++){
scanf("%d%d",&a,&b);
add(a,b);
add(b,a);
//存在环
if(!join(a,b)) flag=false;
}
int count=0;
for(int i=1;i<=n;i++){
if(fa[i]==i) count++;
}
// 不连通或不满足树形结构
if(count>1||m!=n-1) flag=false;
if(!flag){
printf("Graph %d is not a caterpillar.\n",cnt);
continue;
}
//找直径并标记
maxd=0;
d[1]=0;
dfs(1,-1);
d[vv]=0;
dfs(vv,-1);
//若满足条件则所有的边至少有一端在最长直径上
for(int i=1;i<=n;i++){
bool check=false;
if(judge[i]==maxd) continue;
for(int j=head[i];j!=-1;j=e[j].next){
if(judge[e[j].to]==maxd){
check=true;break;
}
}
if(!check){
flag=false;break;
}
}
if(!flag)
printf("Graph %d is not a caterpillar.\n",cnt);
else
printf("Graph %d is a caterpillar.\n",cnt);
}
return 0;
}
这里用并查集判断一个图是否全连通并且无环路:
并查集初始化;
while(输入边的信息(u,v)){
vis[u]=vis[v]=true;
//判断是否成环
如果连接(u,v)前u与v已在一个集合中则说明存在环
}
int root_num=0;
//判断是否全连通
for(int i=1;i<=n;i++)
if(vis[i]&&fa[i]==i) root_num++;
if(root_num>1) flag=false;
树形背包
洛谷P2014 选课
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=310;
int n,m,tot,head[maxn],cnt[maxn];
int dp[maxn][maxn];
struct edge{
int to,next;
}e[2*maxn];
void init(){
tot=0;
memset(head,-1,sizeof(head));
memset(cnt,0,sizeof(cnt));
memset(dp,0,sizeof(dp));
}
void add(int x,int y){
e[++tot].to=y;
e[tot].next=head[x];
head[x]=tot;
}
void solve(int x,int f){
cnt[x]++;
for(int i=head[x];i!=-1;i=e[i].next){
int y=e[i].to;
if(y==f) continue;
solve(y,x);
cnt[x]+=cnt[y];
for(int j=min(m,cnt[x]);j>=0;j--){
for(int k=0;k<=min(cnt[y],j-1);k++){
dp[x][j]=max(dp[x][j],dp[x][j-k]+dp[y][k]);
}
}
}
}
int main()
{
scanf("%d%d",&n,&m);
init();
int pre;
for(int i=1;i<=n;i++){
scanf("%d%d",&pre,&dp[i][1]);
add(pre,i);
}
m++;
solve(0,-1);
printf("%d\n",dp[0][m]);
return 0;
}
注意最后一层循环 k k k 不能取到 j j j,因为必须要先选先修课(父节点)后面的选取才有意义。
CCF201909-5 城市规划
类似 HDU 4118 的处理方法,只不过这里又额外多加了背包的思想
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=5e4+10;
int n,m,K,a,b,c,tot,head[maxn],cnt[maxn];
bool tag[maxn];
ll dp[maxn][110];
struct edge{
int to,next,w;
}e[2*maxn];
void init(int nn){
tot=0;
for(int i=0;i<=nn;i++){
head[i]=-1;cnt[i]=0;tag[i]=0;
}
memset(dp,0x3f,sizeof(dp));
}
void add(int x,int y,int z){
e[++tot].to=y;
e[tot].next=head[x];
e[tot].w=z;
head[x]=tot;
}
void solve(int x,int f){
dp[x][0]=0;
if(tag[x]){
cnt[x]++;
dp[x][1]=0;
}
for(int i=head[x];i!=-1;i=e[i].next){
int y=e[i].to;
if(y==f) continue;
solve(y,x);
cnt[x]+=cnt[y];
for(int j=min(K,cnt[x]);j>=0;j--){
for(int k=0;k<=min(j,cnt[y]);k++){
dp[x][j]=min(dp[x][j],dp[x][j-k]+dp[y][k]+(ll)k*(K-k)*e[i].w);
}
}
}
}
int main()
{
scanf("%d%d%d",&n,&m,&K);
init(n);
int im;
for(int i=1;i<=m;i++){
scanf("%d",&im);
tag[im]=1;
}
for(int i=1;i<n;i++){
scanf("%d%d%d",&a,&b,&c);
add(a,b,c);
add(b,a,c);
}
solve(1,-1);
printf("%lld\n",dp[1][K]);
return 0;
}