题目
思路
不妨给每条边定向,小指向大。那么,由于 a > 0 a>0 a>0 ,为了使一个点的 a a a 被统计次数最少,肯定要两条边接在一起。所以,假如有 x x x 条边指向这个点,有 y y y 条边从这个点指向别人,那么贡献是 max ( x , y ) ⋅ a \max(x,y)\cdot a max(x,y)⋅a
但是 b b b 相等时,边可以任意选一个方向。于是可以树形 d p \tt dp dp ,用 f ( x , 0 / 1 ) f(x,0/1) f(x,0/1) 表示父边的指向情况。然而怎么决策呢?
考虑到贡献只与数量有关,与具体是哪一个无关,我们 贪心地选择变化量小的 即可。
复杂度 O ( n log n ) \mathcal O(n\log n) O(nlogn) 。
代码
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
inline int readint(){
int a = 0; char c = getchar(), f = 1;
for(; c<'0'||c>'9'; c=getchar())
if(c == '-') f = -f;
for(; '0'<=c&&c<='9'; c=getchar())
a = (a<<3)+(a<<1)+(c^48);
return a*f;
}
const int MaxN = 200005;
const long long infty = (1ll<<61)-1;
int b[MaxN], a[MaxN], n;
struct Edge{
int to, nxt;
};
Edge e[MaxN<<1];
int head[MaxN], cntEdge;
void addEdge(int a,int b){
e[cntEdge].to = b;
e[cntEdge].nxt = head[a];
head[a] = cntEdge ++;
}
long long dp[MaxN][2]; // 0:in 1:out
long long pq[MaxN]; // priority_queue
void dfs(int x,int pre){
// printf("dfs %d\n",x);
int cnt[2] = {0,0};
long long tmp = 0;
int tot = 0; // FAKE queue
for(int i=head[x]; ~i; i=e[i].nxt){
if(e[i].to == pre) continue;
dfs(e[i].to,x); // continue working
}
for(int i=head[x],opt; ~i; i=e[i].nxt){
if(e[i].to == pre) continue;
if(b[e[i].to] != b[x])
opt = b[e[i].to] > b[x];
else{
opt = 0; // assume first
pq[tot ++] = dp[e[i].to][0]
- dp[e[i].to][1];
}
++ cnt[opt]; // dirrected
tmp += dp[e[i].to][opt^1];
}
dp[x][0] = dp[x][1] = infty;
sort(pq,pq+tot), pq[tot] = 0;
// printf("tot(%d) = %d\n",x,tot);
for(int i=0; i<=tot; ++i){
dp[x][0] = min(dp[x][0],tmp+1ll*
max(cnt[0]+1,cnt[1])*a[x]);
dp[x][1] = min(dp[x][1],tmp+1ll*
max(cnt[0],cnt[1]+1)*a[x]);
tmp += pq[i]; // change value
-- cnt[0], ++ cnt[1];
}
}
int main(){
n = readint();
for(int i=1; i<=n; ++i)
a[i] = readint();
for(int i=1; i<=n; ++i){
b[i] = readint();
head[i] = -1;
}
for(int i=1,s,y; i<n; ++i){
s = readint(), y = readint();
addEdge(s,y), addEdge(y,s);
}
dfs(1,-1);
int cnt[2] = {0,0};
long long tmp = 0;
int tot = 0; // FAKE queue
int x = 1; // to simplify code
for(int i=head[x],opt; ~i; i=e[i].nxt){
if(b[e[i].to] != b[x])
opt = b[e[i].to] > b[x];
else{
opt = 0; // assume first
pq[tot ++] = dp[e[i].to][0]
- dp[e[i].to][1];
}
++ cnt[opt]; // dirrected
tmp += dp[e[i].to][opt^1];
}
long long ans = infty;
sort(pq,pq+tot), pq[tot] = 0;
for(int i=0; i<=tot; ++i){
ans = min(ans,tmp+1ll*a[x]*
max(cnt[0],cnt[1]));
tmp += pq[i]; // change value
-- cnt[0], ++ cnt[1];
}
printf("%lld\n",ans);
return 0;
}