宝藏(树形DP)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_38601996/article/details/78051270

题目描述

这里写图片描述

输入

这里写图片描述

输出

这里写图片描述

样例输入

6
1 2 1
2 3 3
3 4 36
3 6 13
3 5 2
6 8 9 10 13 1

样例输出

30
29
28
10
30
16

提示

这里写图片描述

solution:

对于50%的数据,O(n²)的树形DP是非常好想的。
f[u][0]表示从u点向下走,不回来的最大收益
f[u][1]表示从u点向下走,要回到u点的最大收益
f[u][1]的转移比较简单

for(int i=head[u];i;i=Next[i]){
        int v=vet[i];
        if(v!=pre){
            dfs(v,u);
            if(f[v][1]-w[i]-w[i]>0)
                f[u][1]+=f[v][1]-w[i]-w[i];
        }
    }

然后是f[i][0]的转移

for(int i=head[u];i;i=Next[i]){
        int v=vet[i];
        if(v!=pre){
            if(f[v][1]-w[i]-w[i]>0)
                f[u][0]=max(f[u][0],f[u][1]+w[i]-f[v][1]+f[v][0]);
            else
                f[u][0]=max(f[u][0],f[u][1]-w[i]+f[v][0]);
        }
    }

枚举根,每次都跑一遍dfs,这样就有50分

void dfs(int u,int pre){
    f[u][0]=f[u][1]=a[u];
    for(int i=head[u];i;i=Next[i]){
        int v=vet[i];
        if(v!=pre){
            dfs(v,u);
            if(f[v][1]-w[i]-w[i]>0)
                f[u][1]+=f[v][1]-w[i]-w[i];
        }
    }
    for(int i=head[u];i;i=Next[i]){
        int v=vet[i];
        if(v!=pre){
            if(f[v][1]-w[i]-w[i]>0)
                f[u][0]=max(f[u][0],f[u][1]+w[i]-f[v][1]+f[v][0]);
            else
                f[u][0]=max(f[u][0],f[u][1]-w[i]+f[v][0]);
        }
    }
}

然后,剩下的50分呢?
画一棵树,模拟一下,会发现每次换根并没有改变所有的f值。
如果前后两个根是相邻的,只有原来的根和新的根的f值会被改变。
所以,只需要用DFS枚举根,每次让根沿着树上的边移动,每次换根重算新旧两个根的f值,然后递归结束后再改回来。
时间复杂度O(n)

code

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<cstring>
using namespace std;
int n,vet[1000000],Next[1000000],head[500000],en;
long long w[1000000],f[500000][2],a[500000],ans[500000];
void addedge(int u,int v,long long val){
    vet[++en]=v;
    w[en]=val;
    Next[en]=head[u];
    head[u]=en;
}
void dfs(int u,int pre){
    f[u][0]=f[u][1]=a[u];
    for(int i=head[u];i;i=Next[i]){
        int v=vet[i];
        if(v!=pre){
            dfs(v,u);
            if(f[v][1]-w[i]-w[i]>0)
                f[u][1]+=f[v][1]-w[i]-w[i];
        }
    }
    for(int i=head[u];i;i=Next[i]){
        int v=vet[i];
        if(v!=pre){
            if(f[v][1]-w[i]-w[i]>0)
                f[u][0]=max(f[u][0],f[u][1]+w[i]-f[v][1]+f[v][0]);
            else
                f[u][0]=max(f[u][0],f[u][1]-w[i]+f[v][0]);
        }
    }
}
void DFS(int u,int pre){
    ans[u]=f[u][0];
    long long tmp0=f[u][0],tmp1=f[u][1],t0,t1;
    for(int i=head[u];i;i=Next[i]){
        int v=vet[i];
        t0=f[v][0];
        t1=f[v][1];
        if(v!=pre){
            f[u][0]=f[u][1]=a[u];
            for(int j=head[u];j;j=Next[j]){
                int vv=vet[j];
                if(vv!=v){
                    if(f[vv][1]-w[j]-w[j]>0)
                        f[u][1]+=f[vv][1]-w[j]-w[j];
                }
            }
            for(int j=head[u];j;j=Next[j]){
                int vv=vet[j];
                if(vv!=v){
                    if(f[vv][1]-w[j]-w[j]>0)
                        f[u][0]=max(f[u][0],f[u][1]+w[j]-f[vv][1]+f[vv][0]);
                    else
                        f[u][0]=max(f[u][0],f[u][1]-w[j]+f[vv][0]);
                }
            }

            f[v][0]=f[v][1]=a[v];
            for(int j=head[v];j;j=Next[j]){
                int vv=vet[j];
                if(f[vv][1]-w[j]-w[j]>0)
                    f[v][1]+=f[vv][1]-w[j]-w[j];
            }
            for(int j=head[v];j;j=Next[j]){
                int vv=vet[j];
                if(f[vv][1]-w[j]-w[j]>0)
                    f[v][0]=max(f[v][0],f[v][1]+w[j]-f[vv][1]+f[vv][0]);
                else
                    f[v][0]=max(f[v][0],f[v][1]-w[j]+f[vv][0]);
            }
            DFS(v,u);
            f[u][0]=tmp0;
            f[u][1]=tmp1;
            f[v][0]=t0;
            f[v][1]=t1;
        }
    }
}
int main(){
    //freopen("treasure.in","r",stdin);
    //freopen("treasure.out","w",stdout);
    scanf("%d",&n);
    for(int i=1;i<n;i++){
        int x,y;
        long long z;
        scanf("%d%d%lld",&x,&y,&z);
        addedge(x,y,z);
        addedge(y,x,z);
    }
    for(int i=1;i<=n;i++)
        scanf("%lld",&a[i]);
    dfs(1,0);
    DFS(1,0);
    for(int i=1;i<=n;i++)
        printf("%lld\n",ans[i]);
    return 0;
}
阅读更多
想对作者说点什么?

博主推荐

换一批

没有更多推荐了,返回首页