Description
给定一棵有 n n n 个节点的树,每个节点有一个权值 w i w_i wi,你需要从前往后一个一个地把点安放进数组 a a a 中。 a a a 需要满足若点 a i a_i ai 为 a j a_j aj 的父亲,则 i i i 必须小于 j j j 。一种安放方案的权值为 ∑ i = 1 n w a [ i ] × i \sum_{i=1}^n w_{a[i]} \times i ∑i=1nwa[i]×i,请找出权值最小的方案并输出最小权值。
1 ≤ n ≤ 1000 1 \leq n \leq 1000 1≤n≤1000。
Solution
如果没有父节点排在节点之前的限制,就可以按照权值从大到小排序的顺序选。对于一个权值极大的点,如果没有限制的话我们越早放这个点越好,但是有限制的话,就必须将这个点与他的父亲绑定在一起,一旦他的父亲放了,就紧挨着他的父亲放进去。
先不考虑如何绑定节点。如果在保持这个绑定关系同时,其他节点不作限制,我们应该如何排列呢? 我们假设绑定在一起的节点是
a
a
a 和
b
b
b 。现有另外一个节点
x
x
x,我们观察两种排列
x
a
b
xab
xab,
a
b
x
abx
abx 对最终的计算结果有什么影响,假设第一个数的位置在序列中是
i
i
i ,这两个排列带来的权值分别是:
v
a
l
1
=
w
x
×
i
+
w
a
×
(
i
+
1
)
+
w
b
×
(
i
+
2
)
val_1 = w_x \times i+w_a \times (i+1) + w_b \times (i+2)
val1=wx×i+wa×(i+1)+wb×(i+2)
v a l 2 = w a × i + w b × ( i + 1 ) + w x × ( i + 2 ) val_2 = w_a \times i+w_b \times (i+1) + w_x \times (i+2) val2=wa×i+wb×(i+1)+wx×(i+2)
显然谁的 v a l val val 小谁就选谁。
v a l 2 − v a l 1 = 2 × w x − ( w a + w b ) val_2 - val_1 = 2 \times w_x - (w_a + w _b) val2−val1=2×wx−(wa+wb)
如果结果 > 0 >0 >0,那么就选前者,否则选后者。这相当于比较 2 × w x 2 \times w_x 2×wx 与 w a + w b w_a + w_b wa+wb 谁大,左右同时除二,比较的是 w x 1 \frac{w_x}{1} 1wx 与 w a + w b 2 \frac{w_a + w_b}{2} 2wa+wb。
发现这个东西好像是平均数?也就是说,设一个序列 w 1 w_1 w1 有 n n n 个节点,另一个序列 w 2 w_2 w2 的序列有 m m m 个节点,我们可以比较 ∑ i = 1 n w 1 i n \frac{ \sum_{i=1}^n w_{1_i}}{n} n∑i=1nw1i 与 ∑ i = 1 m w 2 i m \frac{ \sum_{i=1}^m w_{2_i}}{m} m∑i=1mw2i,小的先选。我们将 w i w_i wi 的平均值称为 c i c_i ci。
考虑如何绑定节点的问题,一开始,每个序列只有一个点,它们的 c i c_i ci 就是 w i w_i wi,每次操作把 c i c_i ci 最大的点合并到他的父节点并记录贡献,一直合并到根节点即可。可以用优先队列实现。
Code
#include <bits/stdc++.h>
using namespace std;
const int N = 1000 + 5;
struct node {
int par, sz, val;
}a[N];
int n, rt, fa[N], vis[N];
priority_queue<pair<double, int> >q;
int find(int x) {
if (fa[x] == x) return x;
return fa[x] = find(fa[x]);
}
int main() {
while (~scanf("%d%d",&n, &rt) && n && rt){
int ans = 0;
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i].val);
a[i].sz = 1, ans += a[i].val;
if (i != rt) q.push(make_pair(a[i].val, i));
}
for (int i = 1; i < n; i++) {
int x, val; scanf("%d%d", &x, &val);
a[val].par = x;
}
for (int i = 1; i <= n; i++) fa[i] = i, vis[i] = 0;
while (!q.empty()) {
int x = q.top().second; q.pop();
if (vis[x]) continue;
vis[x] = 1;
int par = find(a[x].par);
ans += a[par].sz * a[x].val;
a[par].val += a[x].val;
a[par].sz += a[x].sz;
if (par != rt) q.push(make_pair(1.0 * a[par].val / a[par].sz, par));
fa[x] = par;
}
printf("%d\n", ans);
}
return 0;
}