链接:https://www.acwing.com/problem/content/4623/
给定一个 n 个节点的树,节点编号为 1∼n。
请你从中选择一个简单路径(不能包含重复节点或重复的边),并沿所选路径来一场旅行,更具体的说,就是从所选路径的一个端点沿路径前往另一个端点。
注意,所选简单路径可以只由一个节点组成。
旅行需要花费能量。
初始时,你的能量为 0。
在旅行过程中:
每经过一个节点(包括起点和终点),就可以得到该节点的能量,其中节点 i 包含的能量为 wi。
每经过一条边 (u,v),就需要消耗一定的能量 c。
你设计的旅行路线应满足:在经过任何一条边之前,你的现有能量都不能少于该边所需消耗的能量(否则,将无法顺利通过该边)。
在满足条件 1 的前提下,旅行结束时,剩余的能量尽可能大。
请计算并输出剩余能量的最大可能值。输入格式
第一行包含整数 n。第二行包含 n 个整数 w1,w2,…,wn。
接下来 n−1 行,每行包含三个整数 u,v,c,表示存在一条边 (u,v),经过它所需的能量为 c。
保证给定图是一棵树。
输出格式
一个整数,表示剩余能量的最大可能值。数据范围
前 3 个测试点满足 1≤n≤5。
所有测试点满足 1≤n≤3×10e5,0≤wi≤10e9,1≤u,v≤n,u≠v,1≤c≤10e9。输入样例1:
3 1 3 3 1 2 2 1 3 2
输出样例1:
3
输入样例2:
5 6 3 2 5 0 1 2 10 2 3 3 2 4 1 1 5 1
输出样例2:
7
该题用树状dp求解。
所剩能量最大的路径有两种情况,一种是从节点i的子节点到节点i为止,另一种是从节点i的子节点到节点i,然后又到节点i的子节点。
第一种情况求出从节点i的子节点到节点i中所有路径中最大的一条即可。
第二种情况则是求出最大的两条相加。
将两种情况所剩的能量进行比较,取最大的。即为经过节点i所剩的最大能量。
那么设f[i] 表示从i的所有子节点到i这个节点所剩的最大值,sonj是它的第j个子节点,wj是它到子节点sonj所需要的花费,vali是当前节点的权值。
所以状态转移方程 f[i] = max{f[son1]-w1, f[son2]-w2, ..., f[sonj]-wj} + vali
虽然该题有 现有能量如果小于要消耗的能量则不能通过 的限定条件,但是不影响,我们只要把每个点的最大值的初始值设为0即可,这样根据方程的计算,当现有能量小于消耗能量时,其值为负数,由于最大值至少是0,所以不会考虑这一条路径,就相当于没通过一样。
这样状态转移方程则变为 f[i] = max{0, f[son1]-w1, f[son2]-w2, ..., f[sonj]-wj} + vali
代码如下:
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
const int N = 3e5 + 5;
int Next[2*N], w[2*N], e[2*N], head[N], val[N];
int idex = 0;
LL ans = 0, f[N] {0}; //ans用来记录最大值
void add(int x, int y, int z) //存图
{
e[idex] = y;
Next[idex] = head[x];
w[idex] = z;
head[x] = idex ++;
}
void dfs(int x, int u) //第一个参数是当前节点,第二个参数是它的父节点
{
LL maxv1 = 0, maxv2 = 0; //最大值和第二大值,初始最大值都为0,因为不要考虑负数情况
for(int i=head[x]; i!=-1; i=Next[i])
{
if(e[i]==u) continue; //当它的下一个节点等于父节点时,跳过
dfs(e[i], x);
if(f[e[i]]-w[i]>=maxv1) //当最大值可以更新 或出现和最大值相等的值时,更新最大值和第二大值
{
maxv2 = maxv1;
maxv1 = f[e[i]] - w[i];
}
else if(f[e[i]]-w[i]>maxv2) //当值小于最大值,但大于第二大值时,更新第二大值
maxv2 = f[e[i]] - w[i];
}
f[x] = maxv1 + val[x]; //更新到当前节点所剩最大值
ans = max(ans, maxv1+maxv2+val[x]); //求经过该节点的两种情况路径,所剩的最大的值
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int n;
cin >> n;
memset(head, -1, sizeof(head));
for(int i=1; i<=n; i++) cin >> val[i];
for(int i=1; i<n; i++)
{
int x, y, z;
cin >> x >> y >> z;
add(x, y, z);
add(y, x, z);
}
dfs(1, -1);
cout << ans << endl;
return 0;
}