题面
解法
细节处理较多的树形dp
- 不妨假设最终的两个带权重心为 x , y x,y x,y,显然,存在一条边将这棵树分成两部分,每一部分都离那一个重心最近
- 所以,我们可以枚举那一条边,然后求出它分成的两棵树的带权重心
- 考虑一下这个算法的正确性,尽管对于两棵树的带权重心我们枚举的那条边的分割方式并不合法,但是在我们枚举真正合法的那一条边的时候,这两个点依然是不会变的,所以我们这么做并不会影响算法的正确性,即最优解一定会被统计到
- 现在的问题是,我们已经使用了 O ( n ) O(n) O(n)的时间来枚举割边,如何快速找到这两个带权重心
- 观察一下数据规模,可以发现这棵树的深度并不大
- 就讲一下一棵树的带权重心怎么求,另一棵树是完全类似的
- 假设当前我们找到的带权重心为 x x x,考虑它的所有儿子 y y y,如果存在某一个 y y y使得 w ( x ) > w ( y ) w(x)>w(y) w(x)>w(y),那么我们找到 w ( y ) w(y) w(y)最小的哪一个,然后把当前的带权重心变成 y y y, w ( x ) w(x) w(x)是点 x x x作为这棵树的带权重心的权值,这显然是正确的
- 那么,问题又转化为如何求 w ( x ) w(x) w(x)
- 不妨令 f i f_i fi表示点 i i i的子树中所有点 y y y到 i i i的距离 × y ×y ×y的权值之和, g i g_i gi表示 i i i子树外的,这两个值可以分别在两遍dfs中求出
- 然后求 w ( x ) w(x) w(x)就比较容易了,只要分类考虑一下子树内的情况和子树外的情况即可,这个可以 O ( 1 ) O(1) O(1)计算,尽管在计算的时候需要把细节想想清楚
- 时间复杂度: O ( n d ) O(nd) O(nd)
【注意事项】
- 在计算
f
(
x
)
,
g
(
x
)
,
w
(
x
)
f(x),g(x),w(x)
f(x),g(x),w(x)的时候可能细节较多,需要当心,
我这种蒟蒻就调了好久
代码
#include <bits/stdc++.h>
#define ll long long
#define inf 1ll << 60
#define N 50010
using namespace std;
template <typename T> void chkmax(T &x, T y) {x = x > y ? x : y;}
template <typename T> void chkmin(T &x, T y) {x = x > y ? y : x;}
template <typename T> void read(T &x) {
x = 0; int f = 1; char c = getchar();
while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
while (isdigit(c)) x = x * 10 + c - '0', c = getchar(); x *= f;
}
struct Edge {int next, num;} e[N * 3];
int cnt, d[N], a[N], p[N][21];
ll ans, f[N], g[N], sum[N];
void add(int x, int y) {
e[++cnt] = (Edge) {e[x].next, y};
e[x].next = cnt;
}
void dfs1(int x, int fa) {
sum[x] = a[x], d[x] = d[fa] + 1;
for (int i = 1; i <= 20; i++)
p[x][i] = p[p[x][i - 1]][i - 1];
for (int q = e[x].next; q; q = e[q].next) {
int k = e[q].num; if (k == fa) continue;
p[k][0] = x; dfs1(k, x);
f[x] += f[k] + sum[k], sum[x] += sum[k];
}
}
void dfs2(int x, int fa) {
for (int p = e[x].next; p; p = e[p].next) {
int k = e[p].num; if (k == fa) continue;
g[k] = g[x] + f[x] - f[k] - sum[k] + sum[1] - sum[k];
dfs2(k, x);
}
}
int lca(int x, int y) {
if (d[x] < d[y]) swap(x, y);
for (int i = 20; ~i; i--)
if (d[p[x][i]] >= d[y]) x = p[x][i];
if (x == y) return x;
for (int i = 20; ~i; i--)
if (p[x][i] != p[y][i]) x = p[x][i], y = p[y][i];
return p[x][0];
}
ll calc(int rt, int tx, int key) {
int x = rt; ll ret;
if (key == 1) {
int ty = lca(x, tx);
ret = f[x] + g[x] - f[tx] - sum[tx] * (d[x] + d[tx] - 2 * d[ty]);
} else ret = f[x] + g[x] - g[rt] - (sum[1] - sum[rt]) * (d[x] - d[rt]);
while (1) {
ll mn = inf, mni;
for (int p = e[x].next; p; p = e[p].next) {
int y = e[p].num; if (y == tx) continue;
if (key == 0) {
ll tmp = f[y] + g[y] - g[rt] - (sum[1] - sum[rt]) * (d[y] - d[rt]);
if (tmp < mn) mn = tmp, mni = y;
} else {
int ty = lca(y, tx);
ll tmp = f[y] + g[y] - f[tx] - sum[tx] * (d[y] + d[tx] - 2 * d[ty]);
if (tmp < mn) mn = tmp, mni = y;
}
}
if (mn >= ret) break;
ret = mn, x = mni;
}
return ret;
}
void dfs(int x, int fa) {
for (int p = e[x].next; p; p = e[p].next) {
int k = e[p].num; if (k == fa) continue;
chkmin(ans, calc(k, x, 0) + calc(1, k, 1));
dfs(k, x);
}
}
int main() {
int n; read(n); cnt = n;
for (int i = 1; i < n; i++) {
int x, y; read(x), read(y);
add(x, y), add(y, x);
}
for (int i = 1; i <= n; i++) read(a[i]);
dfs1(1, 0), dfs2(1, 0);
ans = inf; dfs(1, 0); cout << ans << "\n";
return 0;
}