题目描述
输入
输出
一行一个整数表示答案
样例
输入1
5
40 10 30 50 20
2 3 2 3 1
1 2
1 3
2 4
2 5
输出1
160
输入2
5
1000000 1 1 1 1
1000000 1 1 1 1
1 2
1 3
1 4
1 5
输出2
4000004
输入3
10
510916 760492 684704 32545 484888 933975 116895 77095 127679 989957
402815 705067 705067 705067 623759 103335 749243 138306 138306 844737
1 2
3 2
4 3
1 5
6 4
6 7
8 7
8 9
9 10
输出3
6390572
说明/提示
分析
考虑贪心。在经过二十分钟的冥思苦想后,我们可以得到一个结论:贪心做不了。。。(至少我没想出来)
于是我们考虑树形DP。定义
d
p
[
i
]
[
0
/
1
]
dp[i][0/1]
dp[i][0/1]为以
i
i
i 为根能得到的最大价值。其中
0
/
1
0/1
0/1 表示轨迹是向上还是向下(
0
0
0 表示向上递增,
1
1
1 表示向上递减)。只要这个状态定义出来了,这道题的难度就大大降低了。如果
a
[
s
o
n
]
<
a
[
i
]
a[son] < a[i]
a[son]<a[i] 还是
a
[
s
o
n
]
>
a
[
i
]
a[son] > a[i]
a[son]>a[i],状态都很好转移,那
a
[
s
o
n
]
=
=
a
[
i
]
a[son] == a[i]
a[son]==a[i] 的时候到底是判断他们递减还是递增呢?我们可以在这里暂时将他们设为向上递增,中间不断改变,比一个最小值即可。那中间怎么改变呢?不是有
2
k
2^k
2k 种方案吗,不妨将
d
p
[
i
]
[
1
]
−
d
p
[
i
]
[
0
]
dp[i][1] - dp[i][0]
dp[i][1]−dp[i][0] 放入一个
v
e
c
t
o
r
vector
vector 里,再排序,每次取其中最小值,再统计一共递增和递减的个数比最大值乘这个点的权值就行了。
这里需要注意,若这个点不是根,它的父节点和它还有一条边,统计个数时,需要在递增或递减的个数中加
1
1
1。
d
p
dp
dp 式子太难打了,看代码吧。。这题还是挺好理解的。。
代码
#include <cstdio>
#include <algorithm>
#include <climits>
#include <cstring>
#include <cmath>
#include <vector>
#define LL long long
using namespace std;
const int MAXN = 2 * 1e5 + 5;
const LL lof = 9 * 1e18;
int n, a[MAXN], b[MAXN];
int Head[MAXN], Ver[MAXN << 1], Next[MAXN << 1], tot;
LL dp[MAXN][2];
vector <LL> v[MAXN];
void add(int x, int y) {
Ver[++ tot] = y;
Next[tot] = Head[x]; Head[x] = tot;
}
void dfs(int x, int fa) {
int tota = 0, totb = 0;
LL Sum = 0;
for(int i = Head[x]; i; i = Next[i]) {
int Y = Ver[i];
if(Y == fa) continue;
dfs(Y, x);
if(b[Y] < b[x]) {
tota ++;
Sum += dp[Y][0];
}
else if(b[Y] > b[x]) {
totb ++;
Sum += dp[Y][1];
}
else {
tota ++;
Sum += dp[Y][0];
v[x].push_back(dp[Y][1] - dp[Y][0]);
}
}
sort(v[x].begin(), v[x].end());
for(int i = 0; i <= v[x].size(); i ++) {
int f = 0;
if(x != 1) f = 1;
dp[x][0] = min(dp[x][0], Sum + (LL)max(tota, totb + f) * a[x]);
dp[x][1] = min(dp[x][1], Sum + (LL)max(tota + f, totb) * a[x]);
tota --; totb ++;
if(i == v[x].size()) continue;
Sum += v[x][i];
}
}
int main() {
int x, y;
scanf("%d", &n);
for(int i = 1; i <= n; i ++) {
scanf("%d", &a[i]);
dp[i][0] = dp[i][1] = lof;
}
for(int i = 1; i <= n; i ++) {
scanf("%d", &b[i]);
}
for(int i = 1; i < n; i ++) {
scanf("%d%d", &x, &y);
add(x, y); add(y, x);
}
dfs(1, -1);
printf("%lld", min(dp[1][0], dp[1][1]));
return 0;
}