给你一棵有n个节点的二叉树,每个节点有一个权值,对于一棵子树u,将u的子树中的节点权值从大到小排序,令sz[u]为子树u的大小,
则ans[u] = 1 * a[1] + 2 * a[2] + ... + sz[u] * a[sz[u]],其中a[1] >= a[2] >= ... >= a[sz[u]]。求所有节点的答案。
先将权值离散化,对每个节点建立一棵只含自身权值的权值线段树,dfs整棵树,自底向上把儿子的线段树合并到父亲上(线段树合并)。
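ans[rt] = ans[ls[rt]] + ans[rs[rt]] + size[rs[rt]] * w[ls[rt]],w表示某权值区间的权值和,size表示某权值区间内点的个数。右儿子对应的权值更大、排名更靠前(占据第1..size[rs[rt]]名),所以左儿子中每个点的排名都要后移size[rs[rt]]位。叶子节点中所有点权值相同(设为v,个数为s),其答案为v * s * (s + 1) / 2。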
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

// For every node u of the tree rooted at 1, sort the weights in u's subtree in
// non-increasing order a[1] >= a[2] >= ... >= a[sz[u]] and report
//   ans[u] = 1*a[1] + 2*a[2] + ... + sz[u]*a[sz[u]].
//
// Approach: one dynamic weight-indexed segment tree per node over the
// discretized weights, merged bottom-up.  Every segment-tree node stores
//   sz  - number of points whose weight lies in the node's value range,
//   sum - total weight of those points,
//   ans - the answer restricted to those points (ranked among themselves).
// The right child covers larger weights, so its points take ranks 1..sz[rs];
// every point of the left child is pushed back by sz[rs] ranks:
//   ans = ans[ls] + ans[rs] + sum[ls] * sz[rs].

const int maxn = 1e5 + 10;
const int maxnode = 2e6 + 10;   // one root-to-leaf path of nodes per tree node

struct edge { int to, next; } e[maxn << 1];
int head[maxn], ecnt;

/// Clears all adjacency lists (head[u] == -1 means "no edges").
void edge_init() {
    ecnt = 0;
    std::memset(head, -1, sizeof(head));
}

/// Appends the directed edge u -> v to u's adjacency list.
void add(int u, int v) {
    e[ecnt].to = v;
    e[ecnt].next = head[u];
    head[u] = ecnt++;
}

int a[maxn], b[maxn];   // a[u]: rank (1..m) of u's weight; b[1..m]: sorted distinct weights
int root[maxn];         // root[u]: segment tree of u's subtree (0 = empty tree)
int sz[maxnode], ls[maxnode], rs[maxnode];
long long ans[maxnode], sum[maxnode];   // index 0 is the shared "empty" node; never written
int tot, m;             // tot: segment-tree nodes allocated so far; m: number of distinct weights

int stk[maxn], ord[maxn], par[maxn];    // scratch arrays for the iterative dfs

/// Merges leaf v into leaf u.  All points in a leaf share the same weight, so
/// the answer is weight * (1 + 2 + ... + sz).
int mergeleaf(int u, int v) {
    sz[u] += sz[v];
    sum[u] += sum[v];
    const long long value = sum[u] / sz[u];    // exact: every point has this weight
    const long long cnt = sz[u];
    // Form the triangular number first: value * cnt * (cnt + 1) can overflow
    // long long even when the final answer fits.
    ans[u] = value * (cnt * (cnt + 1) / 2);
    return u;
}

/// Merges tree v into tree u (both covering [l, r]) and returns the merged
/// root.  Destructive: u's nodes are overwritten in place and subtrees of v may
/// be linked into the result.
int merge(int u, int v, int l, int r) {
    if (!u || !v) return u | v;
    if (l == r) return mergeleaf(u, v);
    int mid = (l + r) >> 1;
    ls[u] = merge(ls[u], ls[v], l, mid);
    rs[u] = merge(rs[u], rs[v], mid + 1, r);
    sz[u] = sz[ls[u]] + sz[rs[u]];
    sum[u] = sum[ls[u]] + sum[rs[u]];
    // Larger weights (right child) rank first; left-child points shift back by sz[rs].
    ans[u] = ans[ls[u]] + ans[rs[u]] + sum[ls[u]] * (long long) sz[rs[u]];
    return u;
}

/// Inserts one point with weight rank x into the tree rooted at rt over [l, r],
/// allocating nodes as needed.  Only called on an empty tree, so every node on
/// the path holds exactly this single point.
void update(int x, int &rt, int l, int r) {
    if (!rt) rt = ++tot;
    sum[rt] = ans[rt] = (long long) b[x];
    sz[rt] = 1;
    if (l == r) return;
    int mid = (l + r) >> 1;
    if (x <= mid) update(x, ls[rt], l, mid);
    else update(x, rs[rt], mid + 1, r);
}

/// Builds root[u] for every node in the tree rooted at src (whose parent is fa).
/// Iterative, so a path-shaped tree of ~1e5 nodes cannot overflow the call
/// stack.  Pass 1 records an order in which every node precedes its
/// descendants.  Pass 2 walks that order backwards, so each subtree is complete
/// before it is merged into its parent.
void dfs(int src, int fa) {
    int top = 0, cnt = 0;
    stk[top++] = src;
    par[src] = fa;
    while (top > 0) {
        const int u = stk[--top];
        ord[cnt++] = u;
        update(a[u], root[u], 1, m);
        for (int i = head[u]; i != -1; i = e[i].next) {
            const int v = e[i].to;
            if (v == par[u]) continue;
            par[v] = u;
            stk[top++] = v;      // tree: each node is pushed once, so top <= n
        }
    }
    for (int k = cnt - 1; k > 0; --k) {   // ord[0] == src has no parent to merge into
        const int u = ord[k];
        root[par[u]] = merge(root[par[u]], root[u], 1, m);
    }
}

int main() {
    int T, n;
    if (std::scanf("%d", &T) != 1) return 0;
    while (T--) {
        edge_init();
        if (std::scanf("%d", &n) != 1) break;
        for (int i = 1; i <= n; ++i) {
            std::scanf("%d", a + i);
            b[i] = a[i];
        }
        // Discretize: a[i] becomes the 1-based rank of its weight among b[1..m].
        std::sort(b + 1, b + 1 + n);
        m = std::unique(b + 1, b + 1 + n) - (b + 1);
        for (int i = 1; i <= n; ++i) a[i] = std::lower_bound(b + 1, b + 1 + m, a[i]) - b;
        for (int u, v, i = 1; i < n; ++i) {
            std::scanf("%d%d", &u, &v);
            add(u, v);
            add(v, u);
        }
        tot = 0;
        std::memset(root, 0, sizeof(root));
        if (n > 0) dfs(1, 0);   // with n == 0, update() would recurse on the empty range [1, 0]
        for (int i = 1; i <= n; ++i) std::printf("%lld ", ans[root[i]]);
        std::puts("");
        // Reset every node used by this test case for the next one.
        for (int i = 1; i <= tot; ++i) ls[i] = rs[i] = sum[i] = ans[i] = sz[i] = 0;
    }
}