题目大意:
给定一棵有n个节点的树,给定每个点的点权,每个边的边权是1。你可以任选一个点作为V点,使得
∑
i
=
1
n
d
i
s
t
(
i
,
v
)
∗
a
[
i
]
\sum_{i=1}^n{dist(i,v) * a[i]}
∑i=1ndist(i,v)∗a[i]最大。输出这个最大值。(n <= 2 * 1e^5)。
注意到这里的n比较大,所以肯定不能对每个点进行搜索。
首先枚举每个点肯定是必须的。我们考虑如何快速的求解出一个点的
∑
i
=
1
n
d
i
s
t
(
i
,
v
)
∗
a
[
i
]
\sum_{i=1}^n{dist(i,v) * a[i]}
∑i=1ndist(i,v)∗a[i]。
我的思路是:
先把1作为树的根节点,计算出
∑
i
=
1
n
d
i
s
t
(
i
,
1
)
∗
a
[
i
]
\sum_{i=1}^n{dist(i,1) * a[i]}
∑i=1ndist(i,1)∗a[i]。然后再枚举2-n这些剩余的点,通过
∑
i
=
1
n
d
i
s
t
(
i
,
1
)
∗
a
[
i
]
\sum_{i=1}^n{dist(i,1) * a[i]}
∑i=1ndist(i,1)∗a[i],计算出其它点的
∑
i
=
1
n
d
i
s
t
(
i
,
v
)
∗
a
[
i
]
\sum_{i=1}^n{dist(i,v) * a[i]}
∑i=1ndist(i,v)∗a[i]。每个点的初始值都是
∑
i
=
1
n
d
i
s
t
(
i
,
1
)
∗
a
[
i
]
\sum_{i=1}^n{dist(i,1) * a[i]}
∑i=1ndist(i,1)∗a[i]。
dep[i]代表点i在树内的深度,另点1的深度为0。
sumN[i]代表点i的i的子树的点权和(包含i)。
sum[k]代表点i的
∑
i
=
1
n
d
i
s
t
(
i
,
k
)
∗
a
[
i
]
\sum_{i=1}^n{dist(i,k) * a[i]}
∑i=1ndist(i,k)∗a[i]。
我们以点7为例,根据关系来看,我们可以分为在点5子树内的点和在点7子树外的点。
首先点7到自身距离为0,首先先减去a[7]*dep[7]。
求子树内的点到点7的值。因为现在sum[7]中存放是点7的子树的点到1的值,所以我们需要减去一部分,减去的是(sumN[7] - a[7]) * dep[7];//子树节点的点权和 乘以 (7到1的距离)。
对于子树外的点,当前7的层数为2,因为我们递归的过程是1->5->7,所以我们考虑怎么通过这两个点计算出子树外的所有点到点7的距离。我们可以发现其他点都是1结点和5结点的不包含7的子树的点。
对于5结点,我们不包含7的子树为6,8。我们想要更新6,8到7的值,则需要先减去a[6]*dep[5](不会再走的路径),再加上a[6] * (dep[7] - dep[5])(新增路径)。结点5到结点7的值也可以用这种方法更新。能否一次更新5,6,8到7的值呢?
我们可以直接加上(sumN[5] - sumN[7]) * (dep[7] - dep[5]) - (sumN[5] - sumN[7]) * (dep[5])。即完成了5结点的不包含7的子树所有结点到结点7的值的更新。结点1的不包含7的子树的结点的值的更新,也是如此。
对于处于n层的结点k,其递归路径为
k
1
−
>
k
2
.
.
.
.
.
−
>
k
n
k_1->k_2.....->k_n
k1−>k2.....−>kn,对于任意一层m,则其不包含
k
n
k_n
kn子树的结点的值的更新为(sumN[
k
m
k_m
km] - sumN[
k
m
+
1
k_{m +1}
km+1]) * (dep[
k
n
k_n
kn] - dep[
k
m
k_{m}
km]) - (sumN[
k
n
k_n
kn] - sumN[
k
m
k_{m}
km]) * (dep[
k
m
k_{m}
km])
化简一下即:
(sumN[
k
m
k_m
km] - sumN[
k
m
+
1
k_{m +1}
km+1]) * dep[
k
n
k_n
kn] - 2 * (sumN[
k
m
k_m
km] - sumN[
k
m
+
1
k_{m +1}
km+1]) *dep[
k
m
k_{m}
km]。
所以我们再进去下一层时需要计算两个值:
sumN[
k
m
k_m
km] - sumN[
k
m
+
1
k_{m +1}
km+1]
2 * (sumN[
k
m
k_m
km] - sumN[
k
m
+
1
k_{m +1}
km+1]) *dep[
k
m
k_{m}
km]。
在递归过程中把两个值累加入两个变量即可。
#include<cstdio>
#define N 1200000
#define LL long long
int n, m, k;
int To[N], Fr[N], Ne[N];
bool v[N];
int dep[N], fa[N];
int a[N];
LL sum[N], he, maxa;
LL sumN[N];
void dfs(int x, int f, int depp){
fa[x] = f;
dep[x] = depp;
for (int i = Fr[x]; i != 0; i = Ne[i]){
if(To[i] == f) continue;
dfs(To[i], x, depp + 1);
sumN[x] = sumN[To[i]] + sumN[x];
}
}
void dfs1(int x, LL sum1, LL sum2){
sum[x] = he - (LL)a[x] * dep[x];
sum[x] = sum[x] - (sumN[x] - a[x]) * dep[x];
sum[x] = sum[x] + sum1 * dep[x] - sum2 * 2;
if(sum[x] > maxa) maxa = sum[x];
for (int i = Fr[x]; i != 0; i = Ne[i]){
if(To[i] == fa[x]) continue;
dfs1(To[i], sum1 + (sumN[x] - sumN[To[i]]), sum2 + (LL)dep[x] * (sumN[x] - sumN[To[i]]));
}
}
int main(){
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]), sumN[i] = a[i];
for (int i = 1; i < n; i++){
int x, y;
scanf("%d%d", &x, &y);
k++;To[k] = y; Ne[k] = Fr[x]; Fr[x] = k;
k++;To[k] = x; Ne[k] = Fr[y]; Fr[y] = k;
}
dfs(1, 0, 0);
for (int i = 1; i <= n; i++) he = he + (LL)dep[i] * a[i];
maxa = he;
sumN[0] = sumN[1];
dfs1(1, 0, 0);
printf("%I64d", maxa);
}