题目地址:
http://www.lydsy.com/JudgeOnline/problem.php?id=4297
= =
题意:
给定一棵有n个点,m个叶子节点的树,其中m个叶子节点分别为1到m号点,每个叶子节点有一个权值r[i]。你需要给剩下n-m个点各指定一个权值,使得树上相邻两个点的权值差的绝对值之和最小。
思路:
QAQ,看了Claris的代码,又自己想了想,但还是有点迷迷糊糊。后来jxt看了这题,说了他的思路,自己才理解这题。然后下面说的是jxt的思路
首先题目给了m个权值确定的叶子节点,那么答案的确定可以通过叶子节点,自底向上地完成。讨论一个情况:当前结点为y,x是y的子节点,x有 k 个已经确定了权值的儿子,目前只考虑x的取值。首先,如果x的取值非常大(或者非常小,不过由于所有权值都是正整数,这种情况不一定会出现),每当 val[x] 变化1,对答案的影响就是 k ,当 val[x] 渐渐变小,变得比一部分子节点权值大(记其为 kbig ),比一部分子节点权值小(记其为 ksmall )的时候, val[x] 的变化给答案带来的变化就是 |kbig−ksmall| 。
上述情况可以得出两个结论:
- 当 |kbig−ksmall| 最小时,x的子树贡献的答案最小,为最优,此时 val[x] 的取值明显是一个范围,在这个范围里, |kbig−ksmall| 达到了最小值。
- 当 val[x] 大小加减1时, |val[y]−val[x]| 的取值变化始终是1,但当 val[x] 没有处于使 |kbig−ksmall| 最小的范围时,显然: val[x] 大小加减1对 |kbig−ksmall| 变化的影响始终大于等于1,则 val[x] 与x子树取值对答案的贡献大于 val[x] 与 val[y] 对答案的贡献。如果x及其子树没有达到最优,那么当x达到最优时的答案一定比当前答案优秀。(即一个树达到最优的条件是其所有子树达到最优)
综上,为了解题,我们只需要从低往上,确定当前节点使其子树节点最优的取值范围,然后用这个取值范围递推出其父节点的取值范围,就可以求出答案。不需要重新建图。
代码:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>
#include <cstdlib>
using namespace std;
#define PB push_back
#define MS(x, y) memset(x, y, sizeof(x))
typedef pair<int, int> P;
typedef long long LL;
const int MAXN = 5e5 + 5;
const LL INF = 1000000000000000000LL;
int n, m;
LL ans;
int l[MAXN], r[MAXN];
int fa[MAXN];
P a[MAXN << 1];
vector<int> edges[MAXN];
void dfs(int u, int fa) {
// cout << "dfs ing ..." << endl;
if (edges[u].size() == 1) return ;
for (int i = edges[u].size() - 1; i >= 0; --i) {
if (edges[u][i] == fa) continue;
dfs(edges[u][i], u);
}
int cnt = 0, v;
LL mn = INF, sum = 0, now;
int fut = 0, pst = 0;
// 确定u使其子树达到最优的取值范围
LL fut_sum = 0, pst_sum = 0;
for (int i = edges[u].size() - 1; i >= 0; --i) {
v = edges[u][i];
if (v == fa) continue;
a[cnt++] = P(l[v], 0);
a[cnt++] = P(r[v], 1);
++fut;
fut_sum += l[v];
}
sort(a, a + cnt);
for (int i = 0; i < cnt; ++i) {
if (a[i].second) {
++pst;
pst_sum += a[i].first;
} else {
--fut;
fut_sum -= a[i].first;
}
now = fut_sum - fut * a[i].first + pst * a[i].first - pst_sum;
if (now < mn) {
mn = now;
l[u] = a[i].first;
}
if (now == mn) r[u] = a[i].first;
}
ans += mn;
}
int main() {
while (~scanf("%d%d", &n, &m)) {
int u, v, lim;
ans = 0;
head = tail = 0;
MS(fa, 0);
MS(deg, 0);
MS(used, false);
for (int i = 1; i <= n; ++i) edges[i].clear();
for (int i = 1; i < n; ++i) {
scanf("%d%d", &u, &v);
edges[u].PB(v);
edges[v].PB(u);
}
for (int i = 1; i <= m; ++i) {
scanf("%d", l + i);
r[i] = l[i];
}
if (n == m) {
for (u = 1; u <= n; ++u) {
for (int i = edges[u].size() - 1; i >= 0; --i) {
v = edges[u][i];
ans += abs(l[u] - l[v]);
}
}
printf("%I64d\n", ans / 2);
continue;
}
dfs(n, 0);
printf("%I64d\n", ans);
}
}