树形dp就是在一棵树上跑动态规划,通常用到dfs等方法实现。
这篇博客就是记录我刷的树形dp,刷题路线参考了y总的课程以及聚聚的博客:【DP_树形DP专辑】
一、常规树形dp
1、求树的中心
树的中心是树上的一个节点,这个节点在距离其他节点的最长路径最短。
这个非常非常基础,同时在很多较复杂的题目中需要用到这个方法解决降低复杂度。
思路分析:
任意选一节点作为根节点
对于任意一个节点,距离它最远的节点有两种情况:
1⃣️最远的节点经过父节点
2⃣️最远的节点不经过父节点
继续观察第1⃣️种情况,可以发现最大距离可能由父节点的(A.最长路径)或者(B.次长路径)更新(如果父节点的最长路径经过该节点则由次长路径更新)。
代码实现:
//寻找树的中心
const int maxn=1e4+7;
struct ed{
int to;
int w;
};
vector<ed> mp[maxn];
struct pos{
int d1,d2;
int v1,v2;
//分别为最大/次大距离,最大/次大距离经过的子节点
}dp[maxn];
int dfs1(int u,int f){
//先用子节点更新父节点
//更新每个节点最大和次大路径以及对应经过的子节点
//函数返回值为子节点的最远距离
if(mp[u].size()==1){
return 0;
}
dp[u].d1=-INF;
dp[u].d2=-INF;
for(int i=0;i<mp[u].size();i++){
int to=mp[u][i].to;
if(to!=f){
int d=dfs1(to,u)+mp[u][i].w;
//dist=max(d,dist);
if(d>=dp[u].d1){
dp[u].d2=dp[u].d1;
dp[u].v2=dp[u].v1;
dp[u].d1=d;
dp[u].v1=to;
}
else if(d>dp[u].d2){
dp[u].d2=d;
dp[u].v2=to;
}
}
}
if(dp[u].d1==-INF){
dp[u].d1=0;
dp[u].d2=0;
}// 判断是否为叶子节点
return dp[u].d1;
}
//用父节点更新子节点
int up[maxn];
void dfs2(int u,int f){
//由父节点更新子节点
for(int i=0;i<mp[u].size();i++){
int to=mp[u][i].to;
if(to==f) continue;
if(dp[u].v1==to){
up[to]=max(up[u],dp[u].d2)+mp[u][i].w;
}
else{
up[to]=max(up[u],dp[u].d1)+mp[u][i].w;
}
dfs2(to,u);
}
}
int main(){
int n;
cin>>n;
for(int i=1;i<=n-1;i++){
int a,b,c;
cin>>a>>b>>c;
mp[a].pb({b,c});
mp[b].pb({a,c});
}
dfs1(1,-1);
dfs2(1,-1);
int res=INF;
for(int i=1;i<=n;i++){
res=min(res,max(up[i],dp[i].d1));
}
cout<<res<<endl;
}
2、树的最长路径(含负边权)
思路分析:
如何划分状态?如何更新ans?
选任意点作为树根(因为是无向树)。
计算每一个以该节点作为最高点的路径,以此更新ans。
那么如何计算以该节点作为最高点的路径呢?————找到除了经过其父节点的最大和次大距离,按顺序dfs即可。
代码实现:
const int maxn=1e4+7;
struct ed{
int to;
int w;
};
vector<ed> mp[maxn];
int ans=0;//记录答案
int dfs(int u,int fa){
//标记father,防止再次回到father节点
//dfs函数return的是该节点走到底的最长路径
int d1=0,d2=0;//记录最大和次大节点
int dist=0;
//cout<<"u="<<u<<endl;
for(int i=0;i<mp[u].size();i++){
int to=mp[u][i].to;
if(to!=fa){
int d=dfs(to,u)+mp[u][i].w;
dist=max(dist,d);
if(d>=d1){
d2=d1;
d1=dist;
}
else if(d>=d2){
d2=d;
}
}
}
ans=max(ans,d1+d2);
return dist;
}
int main(){
int n;
cin>>n;
for(int i=1;i<=n-1;i++){
int a,b,c;
cin>>a>>b>>c;
mp[a].pb({b,c});
mp[b].pb({a,c});
}
dfs(1,-1);
cout<<ans<<endl;
}
二、树上背包问题
1、二叉苹果树思路分析:
其实可以看成一个分组背包问题,所有的物品呈递归依赖关系,只有选择了父节点才能选择子节点。
同时对于任意一个节点,其每个子节点能达到的状态都是唯一的,相当于以其每个子节点为依据分组做分组背包。
const int maxn=110;
int N,Q;
struct ed{
int to;
int w;
};
vector<ed> mp[maxn];
int dp[maxn][maxn];//dp[i][j]表示如果取节点i,容量为j下获得的最大价值
void dfs(int u,int f){
for(int i=0;i<mp[u].size();i++){
int to=mp[u][i].to;
int w=mp[u][i].w;
if(to==f) continue;
dfs(to,u);
for(int j=Q;j>=0;j--){
for(int k=0;k<j;k++){
dp[u][j]=max(dp[u][j],dp[u][j-k-1]+dp[to][k]+w);
}
}
}
}
int main(){
cin>>N>>Q;
for(int i=1;i<=N-1;i++){
int a,b,c;
cin>>a>>b>>c;
mp[a].pb({b,c});
mp[b].pb({a,c});
}
dfs(1,-1);
cout<<dp[1][Q]<<endl;
}
2、有依赖的背包问题
与上题类似,只不过这次节点的“体积”不一定是1,在每条边上加一种权值即可。
const int maxn=110;
int N,Q;
struct ed{
int to;
int v;
int w;
};
vector<ed> mp[maxn];
int dp[maxn][maxn];//dp[i][j]表示如果取节点i,容量为j下获得的最大价值
void dfs(int u,int f){
//cout<<"u="<<u<<endl;
for(int i=0;i<mp[u].size();i++){
int to=mp[u][i].to;
int w=mp[u][i].w;
if(to==f) continue;
dfs(to,u);//在更新父节点之前要先更新子节点
// cout<<"u="<<u<<endl;
// cout<<"to="<<to<<endl;
for(int j=Q;j>=0;j--){
//每个字子节点相关的物品都是一个分组,只能挑一件,所以从大到小更新
//cout<<"j="<<j<<endl;
for(int k=0;k<j;k++){
//枚举子节点权值,选择了这个子节点相关的“物品”之后该“物品”的整体体积要加上边权记录的值,因为dp数组记录的体积没有包含父节点连向子节点的那条边
if(j-k-mp[u][i].v>=0){
dp[u][j]=max(dp[u][j],dp[u][j-k-mp[u][i].v]+dp[to][k]+w);
//cout<<"dp="<<dp[u][j]<<endl;
}
}
//cout<<endl;
}
//cout<<endl;
}
//cout<<endl;
}
int main(){
cin>>N>>Q;
for(int i=1;i<=N;i++){
int v,w,p;
cin>>v>>w>>p;
if(p==-1){
p=0;
}
mp[p].pb({i,v,w});
mp[i].pb({p,v,w});
}
dfs(0,-1);
cout<<dp[0][Q]<<endl;
}
三、树上删边或节点问题
随便揪个例题吧。
如果想到了二分这个题思路就非常清晰了。
但是要注意一个坑点(WA了一上午),详细在代码注释中。
//二分答案?
const int maxn=1e3+7;
struct ed{
int to;
int w;
};
vector<ed> mp[maxn];
int dp[maxn];//记录切断这个枝条上所有叶节点的最低花费
void dfs(int u,int f,int upper){
//同时判断是否可行
int p=0;
dp[u]=0;
for(int i=0;i<mp[u].size();i++){
int to=mp[u][i].to;
if(to==f) continue;
dfs(to,u,upper);
p=1;
if(mp[u][i].w<=upper){
dp[u]+=min(dp[to],mp[u][i].w);
}
else{
dp[u]+=dp[to];
}
}
if(!p)dp[u]=1e6+7;//因为会累加,如果太大了会爆int,呜呜呜呜然后直接就WA了呜呜呜呜
}
int main(){
int n,m;
while(scanf("%d%d",&n,&m)!=EOF){
if(n==0&&m==0) break;
for(int i=1;i<=n;i++){
mp[i].clear();
}
int ma=0;
for(int i=1;i<=n-1;i++){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
mp[a].pb({b,c});
mp[b].pb({a,c});
ma=max(ma,c);
}
int l=0,r=ma;
int ans=-1;
while(l<=r){
int mid=(l+r)>>1;
dfs(1,-1,mid);
if(dp[1]<=m){
ans=mid;
r=mid-1;
}
else{
l=mid+1;
}
}
printf("%d\n",ans);
}
}