题意:给定一棵树,求出树上的一点,使得树上的全部点到该点的距离之和最小。
思路:暴力显然是O(N^2)等死对吧。
我们首先将无根树转化为有根树,然后一边dfs求出f[i],size[i].
f[i]表示以i为根的子树中全部的点到i的距离之和,size[i]表示以i为根的子树的点数。
以下開始脑洞大开:
如今对于我们一開始的那个root,我们已经知道了答案。问题就是怎样高速的推知别的点作为根时的答案。
我们又一次进行一次dfs,当找到x时,我们用dp[fa[x]]+padis[x]*size[fa[x]]更新答案。
我们记录一下当前的dp[x],以及size[x].
每找到一个儿子son,向下dfs时,我们令dp[x]=dp[fa[x]]+size[fa[x]]*padis[x]+dp[x]-dp[son]-size[son]*padis[son],size[x]=size[fa[x]]+size[x]-size[son],然后再向下dfs.
不要问我为什么。。。
我的代码用的是更加脑洞大开的方法。。。
Code:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define N 100010
int head[N], next[N << 1], end[N << 1], len[N << 1];
void addedge(int a, int b, int _len) {
static int q = 1;
len[q] = _len;
end[q] = b;
next[q] = head[a];
head[a] = q++;
}
int num[N];
long long dp[N];
int pa[N], padis[N], size[N];
void dfs(int x, int fa) {
size[x] = num[x];
for(int j = head[x]; j; j = next[j])
if (end[j] != fa)
pa[end[j]] = x, padis[end[j]] = len[j], dfs(end[j], x);
for(int j = head[x]; j; j = next[j])
if (end[j] != fa)
dp[x] += dp[end[j]] + (long long)size[end[j]] * len[j], size[x] += size[end[j]];
}
long long res = 1LL << 60;
int presize[N], sufsize[N], addsize[N], sav[N], top;
long long pre[N], suf[N], add[N];
void work(int x) {
long long ans = add[x] + (long long)addsize[x] * padis[x] + dp[x];
if (ans < res)
res = ans;
register int i, j;
top = 0;
for(j = head[x]; j; j = next[j])
if (end[j] != pa[x])
sav[++top] = end[j];
presize[0] = pre[0] = 0, sufsize[top + 1] = suf[top + 1] = 0;
for(i = 1; i <= top; ++i)
presize[i] = presize[i - 1] + size[sav[i]], pre[i] = pre[i - 1] + dp[sav[i]] + (long long)size[sav[i]] * padis[sav[i]];
for(i = top; i >= 1; --i)
sufsize[i] = sufsize[i + 1] + size[sav[i]], suf[i] = suf[i + 1] + dp[sav[i]] + (long long)size[sav[i]] * padis[sav[i]];
for(i = 1; i <= top; ++i) {
addsize[sav[i]] = addsize[x] + num[x] + presize[i - 1] + sufsize[i + 1];
add[sav[i]] = add[x] + (long long)addsize[x] * padis[x] + pre[i - 1] + suf[i + 1];
}
for(j = head[x]; j; j = next[j])
if (end[j] != pa[x])
work(end[j]);
}
int main() {
int n;
scanf("%d", &n);
register int i, j;
for(i = 1; i <= n; ++i)
scanf("%d", &num[i]);
int a, b, x;
for(i = 1; i < n; ++i) {
scanf("%d%d%d", &a, &b, &x);
addedge(a, b, x);
addedge(b, a, x);
}
dfs(1, -1);
work(1);
printf("%lld", res);
return 0;
}