dsu on tree思想简述
算法思想
树上启发式合并(dsu on tree)通常用于解决不带修改的树上子树查询的问题。一般情况下会在一次dfs中将所有子树的答案计算出来,然后对于每一次询问进行 O ( 1 ) O(1) O(1)查询。
整体算法利用dfs的思想,在dfs的同时不断更新当前答案。首先进行重链剖分,同时根据dfs序对整棵树进行重编号。对于每一个结点,暴力枚举自己的每一个轻儿子对这个结点产生的贡献,然后计算重儿子对这个结点产生的贡献。这样就可以算出这个结点的答案。然后暴力消除轻儿子带来的贡献,留下重儿子的贡献。
所谓留下重儿子的贡献,就是在dfs的过程中维护一个全局数组来记录答案(来解决空间不足的问题)。对于一个结点,我们需要用dfs去寻找每一个子结点产生的贡献。如果是重儿子,那么我们将这个贡献保留到全局当中去,因为重儿子的子树很大,查询很耗时,将其信息保存起来可以减少很多不必要的查询。如果是轻儿子,那么在查询完轻儿子时不在全局保留贡献,下次再用的时候,就下次再来搜,因为轻儿子数量很少,所以搜索时间不会太长。
重要性质:一个节点到根的路径上轻边个数不会超过 log n \log n logn条。
算法整体时间复杂度: O ( n log n ) O(n\log n) O(nlogn)
算法流程
- 先对整棵树根据dfs序进行重编号,同时进行树链剖分。
- dfs遍历整棵树。对于每一个结点,先依次遍历每一个轻儿子,将轻儿子的贡献加入到全局贡献中去;如果有重儿子,再遍历重儿子,将重儿子的贡献遍历到全局中去。查询完所有子结点后统计当前结点的答案。如果这个点是父结点的轻儿子,那么消除这个点的所带来的所有贡献。
例题
Luogu U41492 - 树上数颜色
题目大意
给一棵有
n
n
n个点的、根为1的树,有
m
m
m次查询,每次询问某棵子树的颜色种类数。
(
1
≤
n
,
m
≤
1
0
5
)
(1\leq n,m \leq 10^5)
(1≤n,m≤105)。
题目解法
树上启发式合并的模板题。先对整个树进行树链剖分,同时对于每一个结点根据dfs序重编号。(下面代码中的odfs)
我们全局维护一个数组cnt,cnt[i]表示当前全局状态第 i i i个颜色的出现次数;维护一个变量ctot表示当前cnt数组中数量大于0的颜色数量。
然后对整棵树进行dfs,对于每一个结点:
- 对于当前结点的每一个轻儿子进行dfs。
- 如果当前结点有重儿子,对重儿子进行dfs。
- 对所有轻儿子,统计以其为根结点的子树的颜色数量。我们对树进行dfs序编号的目的,就是让子树的编号连续。这样暴力枚举每一个结点,更新全局信息。
- 把结点u的答案更新到全局信息。
- 通过当前全局的信息更新当前结点的答案。
- 如果这个点不是父结点的重儿子,则将这个点的子树的贡献从全局中删去。
一轮dfs后,一棵树的所有结点的答案就全部计算好了。接下来的 m m m次查询,直接 O ( 1 ) O(1) O(1)输出结果即可。
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const LL N = 100050;
LL n, m, en = 0, tot = 0, ctot = 0;
LL front[N], son[N], sz[N], dfn[N], nod[N], l[N], r[N];
LL cnt[N], ans[N], c[N];
struct Edge {
LL v, next;
}e[N * 4];
void addEdge(LL u, LL v) {
e[++en] = {v, front[u]};
front[u] = en;
}
void odfs(LL u, LL fa) {
l[u] = ++tot;
nod[tot] = u;
sz[u] = 1;
LL ms = 0;
for (LL i = front[u]; i; i = e[i].next) {
LL v = e[i].v;
if (v == fa) continue;
odfs(v, u);
sz[u] += sz[v];
if (ms < sz[v]) {
ms = sz[v]; son[u] = v;
}
}
r[u] = tot;
}
void add(LL u) {
if (cnt[c[u]] == 0) ++ctot;
++cnt[c[u]];
}
void del(LL u) {
--cnt[c[u]];
if (cnt[c[u]] == 0) --ctot;
}
void dfs(LL u, LL fa, bool kp) {
for (LL i = front[u]; i; i = e[i].next) {
LL v = e[i].v;
if (v == fa or v == son[u]) continue;
dfs(v, u, false);
}
if (son[u]) dfs(son[u], u, true);
for (LL i = front[u]; i; i = e[i].next) {
LL v = e[i].v;
if (v == fa or v == son[u]) continue;
for (LL j = l[v]; j <= r[v]; ++j) {
add(nod[j]);
}
}
add(u);
ans[u] = ctot;
if (kp == false) {
for (LL i = l[u]; i <= r[u]; ++i) {
del(nod[i]);
}
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
cin >> n;
for (LL i = 1; i <= n; ++i) {
front[i] = son[i] = 0;
}
for (LL i = 1; i < n; ++i) {
LL x, y;
cin >> x >> y;
addEdge(x, y); addEdge(y, x);
}
for (LL i = 1; i <= n; ++i) {
cin >> c[i];
}
odfs(1, 0);
dfs(1, 0, false);
cin >> m;
for (LL i = 1; i <= m; ++i) {
LL x;
cin >> x;
cout << ans[x] << endl;
}
return 0;
}