学习笔记:树上启发式合并(DSU on tree)

树上启发式合并

启发式算法

         启发式算法是什么呢?

         启发式算法是基于人类的经验和直观感觉,对一些算法的优化。

         给个例子?

         最常见的就是并查集的按秩合并了,有带按秩合并的并查集中,合并的代码是这样的:

void merge(int x, int y) {
  int xx = find(x), yy = find(y);
  if (size[xx] < size[yy]) swap(xx, yy);
  fa[yy] = xx;
  size[xx] += size[yy];
}

         在这里,对于两个大小不一样的集合,我们将小的集合合并到大的集合中,而不是将大的集合合并到小的集合中。

         为什么呢?这个集合的大小可以认为是集合的高度(在正常情况下),而我们将集合高度小的并到高度大的显然有助于我们找到父亲。

         让高度小的树成为高度较大的树的子树,这个优化可以称为启发式合并算法。

树上启发式合并的含义

          树上启发式合并是一种 离线 解决树上问题的算法,算法主旨是通过维护 子树信息 树上问题,应用的范围很广。复杂度一般为 O ( n l o g 2 n ) O(nlog_2n) O(nlog2n),并且十分易于实现。

适用树上启发式合并的题目特征

  1. 可以 离线 操作。
  2. 需要维护 子树信息
  3. 询问的问题 不统一。 如多次询问在树上某点 x x x,距离 x x x k k k 的点的颜色种类数。因为询问中的 k k k 是会变化的,所以使用其它算法会比较难维护。
  4. 能够将维护的信息转化成 维护从儿子到父亲都不变的信息。 比如从儿子到父亲到某一点 x x x 的距离会改变,但是 深度 不会改变。 我们将维护距离信息变成维护深度信息。
  5. 可以在 O ( 1 ) O(1) O(1) O ( l o g 2 n ) O(log_2n) O(log2n) O ( n ) O(\sqrt{n}) O(n ) 的复杂度内 将一个点的信息加入维护的全局数据结构中

例题选讲

树上数颜色

给出一棵 n n n 个节点以 1 1 1 为根的树,节点 u 的颜色为 c u c_u cu,现在对于每个结点 u u u 询问 u u u 子树里一共出现了多少种不同的颜色。 n ≤ 2 × 1 0 5 n \leq 2 \times 10^5 n2×105

分析:

          对于这样的问题肯定要上数据结构。如果可以 离线询问 ,我们考虑 DSU on tree

          具体来说:我们首先对于每一个点 x x x 求出它的重儿子 b i g x big_x bigx。然后我们考虑 递归 解决每一个点内的询问。我们开一个全局的桶 c n t cnt cnt c n t i cnt_i cnti 表示当前 i i i 这种颜色的出现次数。 设 a n s x ans_x ansx 表示 x x x 节点的答案。那么我们按照一下顺序求出 a n s x ans_x ansx。 下文中 遍历 的含义是遍历子树。

          1. 首先遍历 x x x 的 轻(非重)儿子 u u u,计算 u u u 的答案,但 不保留遍历 u u u c n t cnt cnt 数组的影响

          2.遍历 x x x 的重儿子 v v v, 计算出 v v v 的答案,保留遍历 v v v 对答案的影响

          3. 再次遍历 x x x所有 轻儿子 u u u,将遍历结果加入 c n t cnt cnt 中,最后得到 x x x 的答案。

          复杂度证明:

          我们考虑这样做相当于把每一条 轻边 连接的子树枚举一遍。我们考虑一个点会被枚举到的次数等于 它到根的路径上的轻边数量。由于根到任意一点所经过的轻边数量都不超过 l o g 2 n log_2n log2n 条,因此一个点会被枚举 l o g 2 n log_2n log2n 次,所以时间复杂度是 O ( n l o g 2 n ) O(nlog_2n) O(nlog2n) 的。

          这里首先提供一个树上启发式合并的模板:

void add(int x){//加入信息
   ...
}

void del(int x){//删除信息
   ...
}

void dfs0(int x, int fa){// 提前处理一些信息
	L[x] = ++rk; sz[x] = 1; int id = 0;// 求出一个点子树的dfs序列左右端点
	ID[rk] = x;// 反向映射
	for(auto v : E[x]){
		if(v == fa) continue;
		dfs0(v, x);
		sz[x] += sz[v];
		if(sz[v] > sz[id]) id = v;
	}
	R[x] = rk;
	big[x] = id;//存重儿子编号
}

void dfs1(int x, int fa, bool keep){// keep = 0 -> 不保留信息, keep = 1 -> 保留信息
	for(auto v : E[x]){
		if(v == fa || v == big[x]) continue;
		dfs1(v, x, false);// flase 表示不保留  先解决轻儿子,轻儿子不保留
	}
	if(big[x]) dfs1(big[x], x, true);// 存在重儿子才递归,保留信息
	for(auto v : E[x]){
		if(v == fa || v == big[x]) continue;
		for(int i = L[v]; i <= R[v]; i++) add(...); // add函数代表把这个点加入全局维护信息的数据结构中
	}
	add(...);
	ans[x] = ...;
	if(!keep) for(int i = L[x]; i <= R[x]; i++) del(...);//从数据结构中删除
}

         在本题中,我们只需要维护一个 全局的桶 就可以维护信息了。其它的题目可能还要用一些其它的数据结构来维护信息,因此时间复杂度可能会多一个 l o g 2 n log_2n log2n

CODE:

#include<bits/stdc++.h>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
inline int read(){
	int x = 0, f = 1; char c = getchar();
	while(!isdigit(c)){if(c == '-') f = -1; c = getchar();}
	while(isdigit(c)){x = (x << 1) + (x << 3) + (c ^ 48); c = getchar();}
	return x * f;
}
int ID[N];
int cnt;
int n, col[N], L[N], R[N], rk, sz[N], big[N], T, x;
int ans[N], u, v, c[N];
vector< int > E[N];
void dfs0(int x, int fa){
	L[x] = ++rk; sz[x] = 1; int id = 0;
	ID[rk] = x;
	for(auto v : E[x]){
		if(v == fa) continue;
		dfs0(v, x);
		sz[x] += sz[v];
		if(sz[v] > sz[id]) id = v;
	}
	R[x] = rk;
	big[x] = id;
}
void add(int color){
	if(!col[color]) cnt++;
	col[color]++;
}
void del(int color){
	col[color]--;
	if(!col[color]) cnt--;
}
void dfs1(int x, int fa, bool keep){
	for(auto v : E[x]){
		if(v == fa || v == big[x]) continue;
		dfs1(v, x, false);// flase 表示不保留 
	}
	if(big[x]) dfs1(big[x], x, true);
	for(auto v : E[x]){
		if(v == fa || v == big[x]) continue;
		for(int i = L[v]; i <= R[v]; i++) add(c[ID[i]]);
	}
	add(c[x]);
	ans[x] = cnt;
	if(!keep) for(int i = L[x]; i <= R[x]; i++) del(c[ID[i]]);
}
int main(){
	n = read();
	for(int i = 1; i < n; i++){
		u = read(), v = read();
		E[u].pb(v); E[v].pb(u);
	}
	for(int i = 1; i <= n; i++){
		c[i] = read();
	}
	dfs0(1, 0);
	dfs1(1, 0, true);
	T = read();
	while(T--){
		x = read();
		printf("%lld\n", ans[x]);
	}
	return 0;
}

Tree and Queries

题目

简要题意:

给定一棵 n n n 个节点的树,根节点为 1 1 1。每个节点上有一个颜色 c i ci ci​。 m m m 次操作。操作有一种:
u u u k k k:询问在以 u u u 为根的子树中,出现次数 ≥ k ≥k k 的颜色有多少种。
2 ≤ n ≤ 1 0 5 2 \leq n \leq 10^5 2n105 1 ≤ m ≤ 1 0 5 1≤m≤10^5 1m105 1 ≤ c i , k ≤ 1 0 5 1≤ci,k≤10^5 1ci,k105

分析:

          首先,因为 没有修改操作,所以本题可以使用树上启发式合并。

          我们考虑维护一个全局 树状数组,用来快速询问 出现次数小于等于 c c c 的颜色数量数。同时维护一个桶 c o l o r i color_i colori,表示 i i i 这个颜色的数量。设颜色为 x x x 的点新增一个,那么将 c o l o r x color_x colorx 加1,同时在树状数组上把 c o l o r x color_x colorx 位置上减1, c o l o r x + 1 color_x + 1 colorx+1 位置加1。减少的情况类似。答案每次用树状数组查询就好了。

          时间复杂度 O ( n l o g 2 2 n ) O(nlog^{2}_{2}n) O(nlog22n)

CODE:

#include<bits/stdc++.h>
#define pb push_back
using namespace std;
typedef pair< int, int > PII;
const int N = 1e5 + 10;
inline int read(){
	int x = 0, f = 1; char c = getchar();
	while(!isdigit(c)){if(c == '-') f = -1; c = getchar();}
	while(isdigit(c)){x = (x << 1) + (x << 3) + (c ^ 48); c = getchar();}
	return x * f;
}
int sz[N], big[N], L[N], R[N], rk, dep[N], ID[N];
int n, m, col[N], c[N], u, v, x, k, color[N];
int ans[N];
vector< int > E[N];
vector< PII > vec[N];
int lowbit(int x){return x & -x;}
void add(int x, int y){for(; x < N; x += lowbit(x)) c[x] += y;}
int ask(int x){int res = 0; for(; x; x -= lowbit(x)) res += c[x]; return res;}
void dfs0(int x, int fa){
	L[x] = ++rk; sz[x] = 1;
	int id = 0; ID[rk] = x;
	for(auto v : E[x]){
		if(v == fa) continue;
		dfs0(v, x);
		sz[x] += sz[v];
		if(sz[v] > sz[id]) id = v;
	}
	big[x] = id; R[x] = rk;
}
void Add(int col){
	if(color[col]) add(color[col], -1);
	color[col]++;
	add(color[col], 1);
}
void Del(int col){
	add(color[col], -1);
	color[col]--;
	if(color[col]) add(color[col], 1);
}
void dfs1(int x, int fa, bool keep){
	for(auto v : E[x]){
		if(v == fa || v == big[x]) continue;
		dfs1(v, x, false);
	}
	if(big[x]) dfs1(big[x], x, true);
	Add(col[x]);
	for(auto v : E[x]){
		if(v == fa || v == big[x]) continue;
		for(int i = L[v]; i <= R[v]; i++){
			Add(col[ID[i]]);
		}
	}
	for(auto v : vec[x]){
		int k = v.first, id = v.second;
		ans[id] = ask(N - 1) - ask(k - 1);
	}
	if(!keep) for(int i = L[x]; i <= R[x]; i++) Del(col[ID[i]]);
}
int main(){
	n = read(), m = read();
	for(int i = 1; i <= n; i++) col[i] = read();
	for(int i = 1; i < n; i++){
		u = read(), v = read();
		E[u].pb(v); E[v].pb(u);
	}
	for(int i = 1; i <= m; i++){
		x = read(); k = read();
		vec[x].pb(make_pair(k, i));
	}
	dfs0(1, 0);
	dfs1(1, 0, true);
	for(int i = 1; i <= m; i++) printf("%d\n", ans[i]);
	return 0;
}

Arpa’s letter-marked tree and Mehrdad’s Dokhtar-kosh paths

题目

简要题意:

一棵根为 1 1 1 的树,每条边上有一个字符( a a a - v v v 22 22 22 种)。 一条简单路径被称为Dokhtar-kosh当且仅当路径上的字符经过重新排序后可以变成一个回文串。 求每个子树中最长的Dokhtar-kosh路径的长度。

分析:

          只询问全局的最长 Dokhtar-kosh 路径,那么这题很显然可以考虑使用 点分治 处理。但是要针对以每一个点为根的子树都要求出一条最长路径,我们考虑树上启发式合并。

          首先可以分析出来,合法的Dk路径一定是上面的所有 22 22 22 个字母都是偶数,或者只有 1 1 1 个字母是偶数,其它字母是奇数。因为只有这样才能在排序后成为一个回文串。所以 能否成为Dk路径只与每种字母的出现次数的奇偶性有关,我们可以使用状压。

          我们接着沿用点分治的思想:对于以 x x x 为根的子树而言,最长路径要么经过 x x x,要么不经过 x x x。不经过 x x x 的路径可以由 x x x 的儿子所在的子树的答案更新得到,我们现在处理 x x x 的子树内,经过 x x x 的最长Dokhtar-kosh 路径

          具体来讲,我们首先求出每一个点 u u u 到全局根的路径上的字母信息,新增一个字母改变奇偶性可以通过 异或 1 1 1 来实现。这样,重儿子子树内的信息传递给父亲时就不会改变了。然后我们维护一个全局桶 l e n len len l e n m a s k len_{mask} lenmask 代表当前为止,某一个点到根路径所有字母的奇偶性为 m a s k mask mask 的点的最大深度,然后因为最后合法的状态只有 23 23 23 种(全都是偶或者只有一个是奇),我们对于要新加入桶里的状态 t m a s k tmask tmask,枚举这 23 23 23 种状态 o k i ok_i oki,然后用 l e n t m a s k ⨁ o k i len_{tmask \bigoplus ok_i} lentmaskoki 和 当前的深度减去二倍当前子树的根 x x x 深度更新 a n s x ans_x ansx 就好了。 细节可能有点多。

CODE:

#include<bits/stdc++.h>
#define pb push_back
using namespace std;
const int N = 5e5 + 10;
typedef pair< int, int > PII;
char c;
int len[1 << 22], ok[25], dep[N];
int n, fa, sz[N], big[N], L[N], R[N], rk, mask[N], ID[N];
int ans[N];
const int INF = 1e8;
vector< PII > E[N];
void dfs0(int x, int fa){
	L[x] = ++rk; ID[rk] = x; dep[x] = dep[fa] + 1;
	sz[x] = 1; int id = 0;
	for(auto k : E[x]){
		int v = k.first, t = k.second;
		if(v == fa) continue;
		mask[v] = (mask[x] ^ t);
		dfs0(v, x);
		sz[x] += sz[v];
		if(sz[v] > sz[id]) id = v;
	}
	R[x] = rk;
	big[x] = id;
}
void get(int &x, int depth, int tm, int tl){
	for(int i = 0; i <= 22; i++) x = max(x, len[ok[i] ^ tm] - depth + tl - depth);
}
void add(int tm, int tl){
	len[tm] = max(len[tm], tl);
}
void del(int tm){
	len[tm] = -INF;
}
void dfs1(int x, int fa, bool keep){
	for(auto k : E[x]){
		int v = k.first;
		if(v == fa || v == big[x]) continue;
		dfs1(v, x, false);
		ans[x] = max(ans[x], ans[v]);//不经过x 
	}
	if(big[x]) dfs1(big[x], x, true), ans[x] = max(ans[x], ans[big[x]]);
	for(auto k : E[x]){
		int v = k.first;
		if(v == fa || v == big[x]) continue;
		for(int i = L[v]; i <= R[v]; i++) get(ans[x], dep[x], mask[ID[i]], dep[ID[i]]);//加上去 
		for(int i = L[v]; i <= R[v]; i++) add(mask[ID[i]], dep[ID[i]]);
	}
	get(ans[x], dep[x], mask[x], dep[x]);
	add(mask[x], dep[x]);
	if(!keep) for(int i = L[x]; i <= R[x]; i++) del(mask[ID[i]]);// 删去这些状态 
}
int main(){
	fill(len + 1, len + (1 << 22), -INF);
	for(int i = 0; i < 22; i++) ok[i] = (1 << i);
	ok[22] = 0;
	scanf("%d", &n);
	for(int i = 2; i <= n; i++){
	    scanf("%d\n%c", &fa, &c);
		E[fa].pb(make_pair(i, (1 << (c - 'a'))));	
	}
	dfs0(1, 0);
	dfs1(1, 0, true);
	for(int i = 1; i <= n; i++){
		printf("%d ", ans[i]);
	}
	return 0;
}

树上统计

题面

简要题意:

给定一棵 n n n 个节点的树。定义 T r e e [ L , R ] Tree[L, R] Tree[L,R] 表示为了使得 L ∼ R L \sim R LR 号点两两连通,最少需要选择的边的数量。
∑ l = 1 n ∑ r = l n T r e e [ L , R ] \sum_{l = 1}^{n}\sum_{r = l}^{n}Tree[L,R] l=1nr=lnTree[L,R]
n ≤ 1 0 5 n \leq 10^5 n105

分析:

          首先我们考虑如果 [ L , R ] [L, R] [L,R] 确定了,那么选择的边的集合就确定了。因此这道题本质上是一道计数题。

          经典思路,我们考虑 单边贡献,即一条边会被统计多少次。

          对于一条边 ( u , v ) (u, v) (u,v) 而言,不妨设 d e p u > d e p v dep_u > dep_v depu>depv,那么我们把 u u u 的子树里面的点在序列对应位置标记成 0 0 0 u u u 子树外面的点在序列对应位置标记成 1 1 1。那么实际上 ( u , v ) (u,v) (u,v) 会被算的次数等价于询问 有多少区间 [ L , R ] [L, R] [L,R] 满足 [ L , R ] [L, R] [L,R] 里面既有 0 0 0 又有 1 1 1

          正着统计肯定是不好做的,我们考虑正难则反,我们算出有多少区间里面只有 0 0 0 1 1 1,然后拿总区间数减去这个数量就好了。

          这个问题可以用 并查集set 解决。因为我们是在序列里面不断的把 1 1 1 变成 0 0 0。考虑多出的 0 0 0 与左右两边连续的 0 0 0 能多出多少不合法区间,以及少的 1 1 1 会减少多少不合法区间即可。 时间复杂度 O ( n l o g 2 2 n ) O(nlog^{2}_{2}n) O(nlog22n)。不会超时。

CODE:

#include<bits/stdc++.h>
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
typedef pair< int, int > PII;
typedef long long LL;
inline int read(){
	int x = 0, f = 1; char c = getchar();
	while(!isdigit(c)){if(c == '-') f = -1; c = getchar();}
	while(isdigit(c)){x = (x << 1) + (x << 3) + (c ^ 48); c = getchar();}
	return x * f;
}
int sz[N], big[N], rk, L[N], R[N], ID[N], siz[N];
int n, u, v;
int bin[N];
LL res, all;
vector< int > E[N];
struct range{
	int l, r;
	friend bool operator < (range a, range b){
		return a.r < b.r;
	}
};
set< range > s; 
int Find(int x){return bin[x] == x ? x : bin[x] = Find(bin[x]);}
void dfs0(int x, int fa){
	sz[x] = 1; L[x] = ++rk;
	int id = 0; ID[rk] = x;
	for(auto v : E[x]){
		if(v == fa) continue;
		dfs0(v, x);
		if(sz[v] > sz[id]) id = v;
	}
	R[x] = rk; big[x] = id;
}
void Add(int x){//把x位置从0改成1
    bin[x] = x; siz[x] = 1;
	if(bin[x - 1] != -1){
		int f1 = Find(x - 1);//找一找 
		all = all - (1LL * siz[f1] * (siz[f1] + 1LL) / 2LL);
		siz[x] += siz[f1];
		bin[f1] = x;
	} 	
	if(bin[x + 1] != -1){
		int f2 = Find(x + 1);
		all = all - (1LL * siz[f2] * (siz[f2] + 1LL) / 2LL);
		siz[x] += siz[f2];
		bin[f2] = x;
	}
	all = all + (1LL * siz[x] * (siz[x] + 1LL) / 2LL);//正贡献 
	set< range >::iterator it = s.lower_bound(range{x, x});//所在段 
	int r = (*it).r, l = (*it).l;
	s.erase(it);
	all = all - ((r - l + 1) * (r - l + 2LL) / 2LL);
	if(x - 1 >= l) s.insert(range{l, x - 1}), all = all + ((1LL * x - l) * (1LL * x - l + 1LL) / 2LL);
	if(r >= x + 1) s.insert(range{x + 1, r}), all = all + ((1LL * r - x) * (1LL * r - x + 1LL) / 2LL);
}
void dfs1(int x, int fa, bool keep){// 考虑用并查集维护1, 用set维护0 
	for(auto v : E[x]){
		if(v == fa || v == big[x]) continue;
		dfs1(v, x, false);
	}
	if(big[x]) dfs1(big[x], x, true);
	Add(x);
	for(auto v : E[x]){
		if(v == fa || v == big[x]) continue;
		for(int i = L[v]; i <= R[v]; i++) Add(ID[i]);
	}
	if(x != 1) res += ((1LL * n * (n + 1)) / 2LL - all);
	if(!keep){
		for(int i = L[x]; i <= R[x]; i++) bin[ID[i]] = -1;
		s.clear(); s.insert(range{1, n});
		all = (1LL * n * (n + 1) / 2LL);
	}
}
int main(){
	n = read();
	for(int i = 1; i <= n; i++) bin[i] = -1;
	bin[0] = -1; bin[n + 1] = -1;
    s.insert(range{1, n}); all = (1LL * n * (n + 1) / 2LL);
	for(int i = 1; i < n; i++){
		u = read(), v = read();
		E[u].pb(v), E[v].pb(u);
	}
	dfs0(1, 0);
	dfs1(1, 0, true);
	printf("%lld\n", res);
	return 0;
}
/*
10
7 1
1 4
7 6
4 8
6 9
7 5
5 2
8 3
1 10
*/
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值