模板题链接
到目前为止,我觉得树上启发式合并(dsu on tree)是我学到过的最神奇的算法。他用看似十分暴力的方式,跑出一个极低的复杂度(
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn))刚刚学的时候还有点不相信。。。
我们想象,我们现在合并两堆石头(一多一少),我们怎么才是最轻松的?是把多的那堆一个个搬到少的那堆去,还是把少的那堆一个个搬到多的那堆去?显然是后者。
就是说如果有两个个集合要合并,我们肯定是优先想到把小的那个合到大的那个去。
现在我们看看这个题。这个题其实也牵扯到合并的问题。即:从各个子节点合并到父节点。。就像线段树or树形dp一样。但是,明显,用树形dp,要存
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j]以
i
i
i为根的子树
j
j
j颜色的数量。。显然无论在时间还是空间上,都是无法接受的。这时,我收到合并石头的启发(其实是并查集):
第一步:我先像树剖那样,跑出重儿子。
第二步:我们dfs这个树,比如说,我们现在跑到了u节点。首先优先跑所有轻儿子,采用尾递归的方式,统计轻儿子的答案。期间维护一个通数组cnt,存的是以当前轻儿子为根的子树的颜色数量。跑完后将答案赋予这个轻儿子节点,然后清空cnt数组,开始跑下一个轻儿子。
第三步:轻儿子跑完后,我们开始跑重儿子。跑完后,注意:我们不在清空cnt数组,然后逐个再跑一遍各个轻儿子。将结果累计到cnt里面,这里就体现了我们所受的启发:小往大合并。
第四步:经第三步,我们cnt数组就存入了所有子树信息,然后结合u节点自身的信息,将这个答案赋予u就行啦。
这么样?感觉非常暴力因为我们每一个轻儿子都往下dfs了两遍,我们明明已经跑出这个子树的信息存入cnt了,可是还要清空再跑一边???不能忍啊。。。可是很遗憾,这个算法的复杂低的吓人。。具体证明我也不知道。
下面是ac代码:
#include <iostream>
#include <cstring>
#include <string>
#include <algorithm>
#include <queue>
#include <map>
#include <cmath>
#include <cstdio>
#define ll long long
using namespace std;
const int N = 1e5+5;
int ver[N<<1], he[N], ne[N<<1];
int tot = 1;
inline void add(int x, int y)
{
ver[++tot] = y;
ne[tot] = he[x];
he[x] = tot;
}
int siz[N], son[N];
void dfs_bson(int u, int f)
{
siz[u] = 1;
for (int i = he[u]; i; i = ne[i])
{
int y = ver[i];
if (y == f) continue;
dfs_bson(y, u);
siz[u] += siz[y];
if (siz[y] > siz[son[u]])
son[u] = y;
}
}
int col[N], cnt[N];
ll ans[N], sum;
int flag, maxx;
void _count(int u, int f, int val)
{
cnt[col[u]] += val;
if (cnt[col[u]] > maxx)
{
maxx = cnt[col[u]];
sum = col[u];
}
else if (cnt[col[u]] == maxx)
sum += col[u];
for (int i = he[u]; i; i = ne[i])
{
int y = ver[i];
if (y == f || y == flag) continue;
_count(y, u, val);
}
}
void dfs(int u, int f, bool _flag)
{
for (int i = he[u]; i; i = ne[i])
{
int y = ver[i];
if (y == f || y == son[u]) continue;
dfs(y, u, 0);
}
if (son[u])
{
dfs(son[u], u, 1);
flag = son[u];
}
_count(u, f, 1);
flag = 0;
ans[u] = sum;
if (!_flag)
{
_count(u, f, -1);
sum = maxx = 0;
}
}
int main()
{
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &col[i]);
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
add(x, y); add(y, x);
}
dfs_bson(1, 0);
dfs(1,0,0);
for (int i = 1; i <= n; i++)
{
printf("%I64d", ans[i]);
if (i != n) printf(" ");
}
return 0;
}