题意:
树上每个点都有一个对应的颜色。
对于树上的每条路径,其 value 记为这条路径上出现过的所有颜色数的总和。
要求对所有路径的 value 值求和。
思路:
/*
首先再力荐一篇blog,写得很完整具体,十分感谢原Po
http://blog.csdn.net/calabash_boy/article/details/76166110
*/
首先来想另一道老题,
对于树上的每条路径,其 value 记为该路径上出现过的所有颜色的编号的总和,要求所有路径的 value 的和。
这道题只需考虑每一个点对整体的 value 的贡献,即考虑哪些路径经过了该点。
经过该点的路径,其两个端点必然一个在以该点为根的子树中 (size[u] 种选择),另一个不在子树中 (n - size[u] 种选择),点 u 的贡献即为 size[u] * (n - size[u]) * val[u].
只需要树上 dp 一下记录以每个点为根的子树的 size 即可。
但是这道题不一样,因为每个点在经过它的路径里不一定会有贡献。
事实上,每个点在那些 其颜色出现过至少一次 的路径中 发挥了一次贡献;反过来想,即为它在其颜色未出现过的路径中 没有贡献。
现假设某种颜色在所有路径中都出现过,那么它对答案的贡献即为总路径数,n * (n - 1) / 2.
再减去该颜色未出现过的路径的总数,即为该颜色的贡献。
至于怎么求该颜色未出现过的路径的总数,显然需要将树分块,具体的做法,还请参见我上面推荐的那篇博文
(再推荐一次)http://blog.csdn.net/calabash_boy/article/details/76166110
利用的是 DFS序,方法很巧妙
然后我们来回顾一下上面提到的两道题,事实上两道题求的都是某个 部分对整体的贡献,而显然不能像题目字面说的一样去做。
第一题中看的部分是每个点,第二题中看的部分是每种颜色,都保证了不会有重复计算,并且都可计算。
第二题中另一个重要思想是补集思想,所谓的“正难则反”,还是要强化印象。
以及将树分块的方法,也的确十分巧妙,值得学习。
AC代码如下:
#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>
#include <cstring>
#define maxn 200010
using namespace std;
typedef long long LL;
LL n;
vector<int> col[maxn];
int kas, cnt, tot, ne[maxn], a[maxn * 2];
struct Edge {
int to, ne;
Edge(int a = 0, int b = 0) : to(a), ne(b) {}
}edge[maxn * 2];
struct Node {
int l, r, sz, fa;
Node(int a = -1, int b = -1, int c = 1) : l(a), r(b), sz(c) {}
}node[maxn];
void add(int x, int y) {
Edge e(y, ne[x]);
edge[tot] = e;
ne[x] = tot++;
}
void dfs(int u, int fa) {
node[u].fa = fa;
a[cnt++] = u;
node[u].sz = 1;
for (int i = ne[u]; i != -1; i = edge[i].ne) {
Edge e = edge[i]; int v = e.to;
if (v == fa) continue;
dfs(v, u);
node[u].sz += node[v].sz;
}
a[cnt++] = u;
}
bool cmp(int u, int v) { return node[u].l < node[v].l; }
void work() {
memset(ne, -1, sizeof(ne));
tot = 0;
cnt = 1;
for (int i = 0; i <= n; ++i) {
col[i].clear();
node[i].l = node[i].r = -1;
}
for (int i = 1; i <= n; ++i) {
int temp;
scanf("%d", &temp);
col[temp].push_back(i);
}
add(0, 1); add(1, 0);
for (int i = 0; i < n - 1; ++i) {
int x, y;
scanf("%d%d", &x, &y);
add(x, y);
add(y, x);
}
dfs(0, -1);
for (int i = 0; i < cnt; ++i) {
if (node[a[i]].l == -1) node[a[i]].l = i;
else node[a[i]].r = i;
}
// printf("\n");
// for (int i = 0; i <= n; ++i) {
// printf("%d %d %d\n", i, node[i].l, node[i].r);
// }
LL ans = (n * (n - 1) >> 1) * n;
for (int i = 1; i <= n; ++i) {
if (col[i].empty()) {
ans -= n * (n - 1) >> 1;
continue;
}
col[i].push_back(0);
sort(col[i].begin(), col[i].end(), cmp);
// printf("%d\n", col[i].size());
for (auto& u : col[i]) {
for (int ii = ne[u]; ii != -1; ii = edge[ii].ne) {
Edge e = edge[ii]; int v = e.to;
if (v == node[u].fa) continue;
LL sz = node[v].sz;
node[n + 1].l = node[v].l;
while (true) {
auto it = lower_bound(col[i].begin(), col[i].end(), n + 1, cmp);
if (it == col[i].end() || node[*it].l > node[v].r) break;
sz -= node[*it].sz;
node[n + 1].l = node[*it].r;
}
// printf("%d %d %d\n", u, v, sz);
ans -= sz * (sz - 1) >> 1;
}
}
}
printf("Case #%d: %lld\n", ++kas, ans);
}
int main() {
while (scanf("%lld", &n) != EOF) work();
return 0;
}