签 到 题 签到题 签到题
正 解 部 分 \color{red}{正解部分} 正解部分
这道题 每个子树看成一个子问题, 求出每个子树的答案, 然后往上合并得到总答案 .
设当前节点有 2 2 2 个子树, 权值和 和 节点数量 分别是 s u m 1 , s i z e 1 , s u m 2 , s i z e 2 sum_1, size_1, sum_2, size_2 sum1,size1,sum2,size2,子树内的答案为 a n s 1 , a n s 2 ans_1, ans_2 ans1,ans2
则先往
1
1
1 儿子走对答案的贡献为:
a
n
s
1
+
s
u
m
1
+
a
n
s
2
+
(
s
i
z
e
1
+
1
)
×
s
u
m
2
ans_1+ sum_1+ ans_2 + (size_1+1)\times sum_2
ans1+sum1+ans2+(size1+1)×sum2
往
2
2
2 儿子走对答案的贡献为:
a
n
s
2
+
s
u
m
2
+
a
n
s
1
+
(
s
i
z
e
2
+
1
)
×
s
u
m
1
ans_2+ sum_2+ ans_1 + (size_2+1)\times sum_1
ans2+sum2+ans1+(size2+1)×sum1,
当 走 1 1 1儿子 比 走 2 2 2儿子 更优时,
a n s 1 + s u m 1 + a n s 2 + ( s i z e 1 + 1 ) × s u m 2 < a n s 2 + s u m 2 + a n s 1 + ( s i z e 2 + 1 ) × s u m 1 ans_1+ sum_1+ ans_2 + (size_1+1)\times sum_2 < ans_2+ sum_2+ ans_1 + (size_2+1)\times sum_1 ans1+sum1+ans2+(size1+1)×sum2<ans2+sum2+ans1+(size2+1)×sum1
化简得 s i z e 1 × s u m 2 < s i z e 2 × s u m 1 size_1 \times sum_2 < size_2\times sum_1 size1×sum2<size2×sum1 .
所以以 s i z e x × s u m y size_x \times sum_y sizex×sumy 从小到大排序后, 从小到大按顺序 d f s dfs dfs 即可实现答案最优 .
现在已经解决了当根固定时的答案, 考虑如何计算 所有节点作为根的 最优值,
可以想到 先求出以 1 1 1 为根 的答案, 然后进行 换根,
现在已经计算出了
a
n
s
x
ans_x
ansx, 且要将 根的位置 从
x
→
y
x \rightarrow y
x→y, 要求
y
y
y 为根的答案,
首先观察 树的信息 哪里发生了变化,
- 以 y y y为根 的子树 从 x x x 的子树中移除掉了, s i z e x − = s i z e y , s u m x − = s u m y size_x -=size_y,sum_x-=sum_y sizex−=sizey,sumx−=sumy
- 以 x x x为根 的子树 成为了 y y y 的新子树, s i z e y + = s i z e x , s u m y + = s u m x size_y += size_x, sum_y += sum_x sizey+=sizex,sumy+=sumx .
对
a
n
s
x
ans_x
ansx 的影响为
a
n
s
x
−
=
a
n
s
y
+
s
i
z
e
y
前
子
树
×
s
u
m
y
+
s
i
z
e
y
×
s
u
m
y
后
子
树
ans_x -= ans_y + size_{y前子树}\times sum_y + size_y \times sum_{y后子树}
ansx−=ansy+sizey前子树×sumy+sizey×sumy后子树,
其中
a
n
s
y
ans_y
ansy 在往下递归的时候使用子树信息计算即可 .
实 现 部 分 \color{red}{实现部分} 实现部分
#include<bits/stdc++.h>
#define reg register
#define pb push_back
typedef long long ll;
int read(){
char c;
int s = 0, flag = 1;
while((c=getchar()) && !isdigit(c))
if(c == '-'){ flag = -1, c = getchar(); break ; }
while(isdigit(c)) s = s*10 + c-'0', c = getchar();
return s * flag;
}
const int maxn = 200005;
int N;
int num0;
int A[maxn];
int size[maxn];
int head[maxn];
ll tot;
ll Ans;
ll sum[maxn];
ll ans[maxn];
struct Edge{ int nxt, to; } edge[maxn << 1];
void Add(int from, int to){
edge[++ num0] = (Edge){ head[from], to };
head[from] = num0;
}
bool cmp(int a, int b){ return size[a]*sum[b] < size[b]*sum[a]; }
void DFS_1(int k, int fa){
std::vector <int> B;
sum[k] = A[k], size[k] = 1;
for(reg int i = head[k]; i; i = edge[i].nxt){
int to = edge[i].to;
if(to == fa) continue ; B.pb(to);
DFS_1(to, k);
sum[k] += sum[to], size[k] += size[to];
}
std::sort(B.begin(), B.end(), cmp);
ans[k] = A[k]; ll last = 1;
for(reg int i = 0; i < B.size(); i ++){
int to = B[i];
ans[k] += ans[to] + last * sum[to], last += size[to];
}
}
void DFS_2(int k, int fa){
std::vector <int> B;
for(reg int i = head[k]; i; i = edge[i].nxt) B.pb(edge[i].to);
std::sort(B.begin(), B.end(), cmp);
ans[k] = A[k]; ll last = 1;
for(reg int i = 0; i < B.size(); i ++){
int to = B[i];
ans[k] += ans[to] + last * sum[to];
last += size[to];
}
Ans = std::min(Ans, ans[k]);
last = 1; ll suf = tot - A[k];
for(reg int i = 0; i < B.size(); i ++){
int to = B[i]; suf -= sum[to];
if(to != fa){
ll t1 = ans[k], t2 = ans[to];
ans[k] -= ans[to] + last*sum[to] + size[to]*suf;
size[k] -= size[to], sum[k] -= sum[to];
size[to] += size[k], sum[to] += sum[k];
DFS_2(to, k);
size[to] -= size[k], sum[to] -= sum[k];
size[k] += size[to], sum[k] += sum[to];
ans[k] = t1, ans[to] = t2;
}
last += size[to];
}
}
int main(){
N = read();
for(reg int i = 1; i < N; i ++){ int u = read(), v = read(); Add(u, v), Add(v, u); }
for(reg int i = 1; i <= N; i ++) A[i] = read(), tot += A[i];
DFS_1(1, 1);
Ans = ans[1]; DFS_2(1, 1);
printf("%lld\n", Ans);
return 0;
}