题目大意:有一棵苹果树,每个节点有一个苹果,吃掉 u u u 点的苹果能获得 a u a_u au点 HP,经过第 i i i 条边需要消耗 w i w_i wi HP,在原地等待一秒可以获得 1 1 1 HP,每条边只能经过两次,问从1号节点出发吃掉所有苹果最少需要等待多少秒。
分析:首先在某个点一次性把所需要的HP等够是显然正确的,根据每条边只能经过两次,解的形式一定是先吃掉某棵子树,再回到根,再吃其它子树,最优解是一个吃的顺序的问题,而吃掉子树是一个子问题,考虑用树形 dp。
设 d p [ u ] dp[u] dp[u] 表示以 u u u 为根的答案, s u m [ u ] sum[u] sum[u] 表示吃完 u u u 的子树后 HP 的变化。
s
u
m
[
u
]
sum[u]
sum[u] 转移方程:
s
u
m
[
u
]
=
∑
v
∈
s
o
n
(
u
)
(
s
u
m
[
v
]
−
2
∗
w
v
)
sum[u] = \sum_{v\in son(u)}(sum[v] - 2 * w_v)
sum[u]=∑v∈son(u)(sum[v]−2∗wv)
d
p
[
u
]
dp[u]
dp[u] 的转移要考虑子树的遍历顺序。
对于第 i i i 棵子树,进入到这棵子树需要 a i = d p [ i ] + w i a_i = dp[i] + w_i ai=dp[i]+wi,从他回到 u u u 后 得到的 HP 为 b i = d p [ i ] + s u m [ i ] − w i b_i=dp[i]+sum[i]-w_i bi=dp[i]+sum[i]−wi
对遍历顺序进行分类讨论,简化问题,考虑只有两棵子树时的遍历顺序:
先遍历 b i ≥ a i b_i \geq a_i bi≥ai 的子树更优:
设有
a
i
≤
b
i
,
a
j
>
b
j
a_i \leq b_i,a_j > b_j
ai≤bi,aj>bj,分别讨论先
i
i
i 和先
j
j
j 的答案:
1、若
a
i
≥
a
j
a_i \geq a_j
ai≥aj,则两种情况答案分别为:
a
i
a_i
ai,
a
j
+
a
i
−
b
j
a_j + a_i - b_j
aj+ai−bj
2、若
a
i
<
a
j
a_i < a_j
ai<aj,则两种情况答案分别为:
a
i
+
m
a
x
(
a
j
−
b
i
,
0
)
a_i + max(a_j - b_i,0)
ai+max(aj−bi,0) ,
a
j
+
m
a
x
(
a
i
−
b
j
,
0
a_j + max(a_i - b_j,0
aj+max(ai−bj,0,画一下发现先
i
i
i 答案至多为
a
j
a_j
aj,而先
j
j
j 答案至少为
a
j
a_j
aj
若均满足
a
<
b
a<b
a<b:
设有
a
i
≤
b
i
,
a
j
≤
b
j
a_i \leq b_i, a_j \leq b_j
ai≤bi,aj≤bj,设
a
i
<
a
j
a_i < a_j
ai<aj
两种情况答案分别为:
a
i
+
m
a
x
(
a
j
−
b
i
,
0
)
a_i + max(a_j - b_i,0)
ai+max(aj−bi,0),
a
j
a_j
aj,由于
b
i
≥
a
i
b_i \geq a_i
bi≥ai,显然有
a
i
+
m
a
x
(
a
j
−
b
i
,
0
)
≤
a
j
a_i + max(a_j - b_i,0) \leq a_j
ai+max(aj−bi,0)≤aj
若均满足
a
>
b
a>b
a>b:
设有
a
i
>
b
i
,
a
j
>
b
j
a_i > b_i, a_j > b_j
ai>bi,aj>bj
两种情况分别为:
a
i
+
m
a
x
(
a
j
−
b
i
,
0
)
a_i + max(a_j - b_i,0)
ai+max(aj−bi,0),
a
j
+
m
a
x
(
a
i
−
b
j
,
0
)
a_j + max(a_i - b_j,0)
aj+max(ai−bj,0),在纸上画一下可以发现先遍历
b
b
b 更大的子树更优。
代码:
#include<bits/stdc++.h>
using namespace std;
#define pii pair<int,int>
#define fir first
#define sec second
typedef long long ll;
const int maxn = 1e5 + 10;
vector<pii> g[maxn];
int t, n, a[maxn];
ll dp[maxn], sum[maxn];
struct node {
ll in, out;
node() {}
node(ll i,ll j) {
in = i, out = j;
}
};
bool cmp1(node x,node y) {
return x.in < y.in;
}
bool cmp2(node x,node y) {
return x.out > y.out;
}
void dfs(int u,int fa) {
dp[u] = 0, sum[u] = a[u];
vector<node> x, y;
for (auto it : g[u]) {
int v = it.fir, w = it.sec;
if (v == fa) continue;
dfs(v,u);
sum[u] += sum[v] - 2ll * w;
node t;
if (w <= dp[v] + sum[v])
t = node(w + dp[v],dp[v] + sum[v] - w);
else
t = node(2ll * w - sum[v],0);
if (t.in <= t.out)
x.push_back(t);
else
y.push_back(t);
}
sort(x.begin(),x.end(),cmp1);
sort(y.begin(),y.end(),cmp2);
ll cur = a[u];
for (auto it : x) {
if (it.in > cur) {
dp[u] += it.in - cur;
cur = it.in;
}
cur = cur - it.in + it.out;
}
for (auto it : y) {
if (it.in > cur) {
dp[u] += it.in - cur;
cur = it.in;
}
cur = cur - it.in + it.out;
}
}
int main() {
scanf("%d",&t);
while (t--) {
scanf("%d",&n);
for (int i = 1; i <= n; i++)
g[i].clear();
for (int i = 1; i <= n; i++)
scanf("%d",&a[i]);
for (int i = 1; i < n; i++) {
int u, v, w;
scanf("%d%d%d",&u,&v,&w);
g[u].push_back(pii(v,w));
g[v].push_back(pii(u,w));
}
dfs(1,0);
printf("%lld\n",dp[1]);
}
return 0;
}