链接:https://www.nowcoder.com/acm/contest/140/H
来源:牛客网
时间限制:C/C++ 2秒,其他语言4秒
空间限制:C/C++ 262144K,其他语言524288K
64bit IO Format: %lld
题目描述
White Cloud has a tree with n nodes.The root is a node with number 1. Each node has a value.
White Rabbit wants to travel in the tree 3 times. In Each travel it will go through a path in the tree.
White Rabbit can't pass a node more than one time during the 3 travels. It wants to know the maximum sum value of all nodes it passes through.
输入描述:
The first line of input contains an integer n(3 <= n <= 400001) In the next line there are n integers in range [0,1000000] denoting the value of each node. For the next n-1 lines, each line contains two integers denoting the edge of this tree.
输出描述:
Print one integer denoting the answer.
示例1
输入
复制
13 10 10 10 10 10 1 10 10 10 1 10 10 10 1 2 2 3 3 4 4 5 2 6 6 7 7 8 7 9 6 10 10 11 11 12 11 13
输出
复制
110
树形dp。
令dp(i,j,k)表示在i的子树上存在j条选择的链,并且有k条是接在i这个节点上的。
这里为了操作简单,默认链都是竖直链,也即是说接在一个节点上的链最多2条
之后根据各种情况转移即可,具体转移可以看代码。
#include<bits/stdc++.h>
#define mp make_pair
#define fir first
#define se second
#define ll long long
#define pb push_back
using namespace std;
const int maxn=4e5+10;
const ll mod=1e9+7;
const int maxm=1e6+10;
const double eps=1e-7;
const ll inf=(ll)1e13;
struct Edge{
int u,v,next;
}edge[maxm];
int head[maxn];
int tot=0;
void init(){
memset(head,-1,sizeof(head));
tot=0;
}
void addedge(int u,int v){
edge[tot]=Edge{u,v,head[u]};
head[u]=tot++;
}
int n;
ll a[maxn];
ll dp[maxn][8][4];
//dp(i,j,k)表示i子树上有j条链,i节点上有k条链的结果
void dfs(int u,int fa){
dp[u][0][0]=0;
//啥都没有,肯定是0了
dp[u][1][1]=a[u];
//只有自己一个节点作为一条链的答案
for (int i=head[u];~i;i=edge[i].next){
int v=edge[i].v;
if (v==fa) continue;
dfs(v,u);
for (int j=4;j>=0;j--){
for (int k=j-1;k>=0;k--){
dp[u][j][2]=max(dp[u][j][2],dp[u][k][1]+dp[v][j-k][1]);
//其他子树上提供了k条链,并且有一条在u上,那么剩下的j-k条由v提供,并提供一条u-v的链
dp[u][j][2]=max(dp[u][j][2],dp[u][k][2]+max(dp[v][j-k][0],max(dp[v][j-k+1][2],dp[v][j-k][1])));
//其他子树上提供了k条链,并且有两条在u上,那么剩下的j-k条由v提供,不提供u节点到v的链
//这时候v节点的dp值取0条连在v上,一条连在v上,两条连在v上的dp的最大值
//取两条的时候,数目要+1
dp[u][j][1]=max(dp[u][j][1],dp[u][k][0]+dp[v][j-k][1]+a[u]);
//其他子树上提供了k条链,并且有一条在u上,那么剩下的j-k条由v提供,并提供一条u-v的链
dp[u][j][1]=max(dp[u][j][1],dp[u][k][1]+max(dp[v][j-k][0],max(dp[v][j-k+1][2],dp[v][j-k][1])));
//其他子树上提供了k条链,并且有一条在u上,那么剩下的j-k条由v提供,不提供u节点到v的链
//这时候v节点的dp值取0条连在v上,一条连在v上,两条连在v上的dp的最大值
//取两条的时候,数目要+1
dp[u][j][0]=max(dp[u][j][0],dp[u][k][0]+max(dp[v][j-k][0],max(dp[v][j-k+1][2],dp[v][j-k][1])));
//其他子树上提供了k条链,并且没有在u上,那么剩下的j-k条由v提供,不提供u节点到v的链
//这时候v节点的dp值取0条连在v上,一条连在v上,两条连在v上的dp的最大值
//取两条的时候,数目要+1
}
}
}
}
int main(){
scanf("%d",&n);
for (int i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
init();
for (int i=1;i<n;i++){
int u,v;
scanf("%d %d",&u,&v);
addedge(u,v);
addedge(v,u);
}
for (int i=1;i<=n;i++){
for (int j=0;j<=5;j++){
for (int k=0;k<=3;k++){
dp[i][j][k]=-1*inf;
}
}
}
dfs(1,0);
ll ans=max(dp[1][3][0],max(dp[1][3][1],dp[1][4][2]));
printf("%lld\n",ans);
return 0;
}