树链剖分(重链剖分)

一.什么是树链剖分

      树剖是通过某些特殊的划分方法,将树上的节点划分到 不同 的链中,并且保证同一条链上的各个节点的 dfs序 连续,这样就可以用 线段树 对每一条链进行维护。从而方便解决 树上单条路径的查询和修改 问题。需注意:在使用树剖时树的形态不可以变化,只能改变路径上节点的信息 。树剖的核心在于 通过特殊的 划分方式 将 树上路径问题 转化为 区间 问题


     树剖一般解决的问题: 树上单条路径的查询和修改,子树的修改和查询
     对于一棵有 n n n 个节点的树,使用树剖可使每次 查询 或 修改 的复杂度达到 O ( l o g 2 n ) O(log^2n) O(log2n)
     与点分治的应用差别:1.树剖一般一次仅处理单条路径,而 点分治 则是同时考虑多条路径求全局最优或进行计数统计。
                                          2.树剖支持修改,可动态维护信息,而点分治一般解决无修改的问题。


二.怎样剖分

       首先需要明确几个概念:
        1.轻/重儿子:定义 s i z x siz_x sizx 表示以 x x x 为根的子树的大小。对于节点 u u u而言,它的儿子中 s i z siz siz 最大的即为它的重儿子,其余的为轻儿子。
       2. 轻/重边 : 节点 u u u 向它的重儿子 v v v 所连的边 ( u u u, v v v) 叫做重边,其余边为轻边。
       3. 重链:由重边相连所形成的链。

        接着提出两条性质:
        1. 若 v v v u u u 的轻儿子,则一定有 s i z v siz_v sizv < s i z u 2 \frac{siz_u}{2} 2sizu
        2.对于任意一个节点 x x x,从 x x x 到根的路径中所经过的 轻边重链 数都不超过 l o g 2 n 。 log_2n。 log2n

        性质的证明:
        1.可以考虑 反证法,若 s i z v siz_v sizv > = >= >= s i z u 2 \frac{siz_u}{2} 2sizu,则 s i z v siz_v sizv 一定大于 s i z u − 1 2 \frac{siz_u - 1}{2} 2sizu1,那么 v v v 一定比他的兄弟的 s i z siz siz 都要大,与 v v v u u u轻儿子相矛盾。故性质得证。
       2.通过性质 1 1 1可知道: 一条从 x x x 到根的路径所经过的轻边数量一定小于 l o g 2 n log_2n log2n(每次除以2)。因为 重链之间是靠轻边连接的, 因此 重链与轻边的数量之差的绝对值不超过1。(一个点也可看作一条重链) 因此 轻边重链 的数量均不超过 l o g 2 n log_2n log2n

       知道了这些性质,我们可以考虑构造算法:使每一条 重链上的点的 dfs序 连续。当查询或修改一条路径时, 在线段树上将该节点所在的重链上的信息 查询 或 修改,接着从该点跳到 上一条重链上。这样最多跳 l o g 2 n log_2n log2n 次,每次线段树上操作复杂度为 l o g 2 n log_2n log2n。故整体复杂度为 l o g 2 2 n log_2^2n log22n

轻重边划分的代码:

void dfs_chain(int x, int chain){//chain表示当前链的最上方端点 
	bel[x] = chain;
	dfn[x] = ++rst;//rst维护dfs序
	id[rst] = x;//反向映射,建线段树时用
	int k = 0;
	for(int i = head[x]; i; i = E[i].last){
		int v = E[i].v;
		if((dep[v] == dep[x] + 1) && (cnt[v] > cnt[k])) k = v;
	}
	if(k != 0) dfs_chain(k, chain);//保证划分完一条重链后再一次划分其它重链,同一条重链上的dfs序都是连续的
	for(int i = head[x]; i; i = E[i].last){
		int v = E[i].v;
		if((dep[v] == dep[x] + 1) && (v != k)) dfs_chain(v, v);
	}
}

s i z siz siz 代码:

void dfs(int x, int fa){
	dep[x] = dep[fa] + 1;
	cnt[x]++; fat[x] = fa;
	for(int i = head[x]; i; i = E[i].last){
		int v = E[i].v;
		if(v == fa) continue;
		dfs(v, x);
		cnt[x] += cnt[v];
	}
}

实现一条自下而上的路径 ( u , v ) (u, v) (u,v) 上的信息查询,修改。代码:

while(bel[u] != bel[v]){
	change(1, dfn[bel[u]], dfn[u],...);//修改
	ask(1, dfn[bel[u]], dfn[u]);//查询
	u = fa[bel[u]];//跳到上一条重链上 
}
change(1, dfn[v], dfn[u],...);//此时u和v在一条重链上,dfs序一定连续,直接修改/查询
ask(1, dfn[v], dfn[u]); 

树剖求 LCA 代码:

int lca(int u, int v){
	while(bel[u] != bel[v]){//只要不在一条重链上
		if(dep[bel[u]] > dep[bel[v]]) u = fa[bel[u]];//优先跳顶部深度大的
		else v = fa[bel[v]];
	}
	return dep[u] > dep[v] ? v : u;//此时深度小的就是LCA
}

注意:仅用树链剖分求LCA的复杂度为 O ( l o g 2 n ) O(log_2n) O(log2n)

知道这些以后,我们就可解决树上任意一条路径的查询/修改问题。只需求出点对 ( u , v ) (u, v) (u,v) 之间的 lca,然后分别处理 u u u 到 lca 和 v v v 到 lca 之间的路径即可。

三.例题

用树剖解决路径问题的难点在于 线段树中信息的维护。如果能在线段树中维护好信息,那么套上树剖板子即可。

1.[ZJOI2008] 树的统计题面

分析:在线段树上维护区间最大值和区间和即可。
CODE:

#include<bits/stdc++.h>//树链剖分 + 线段树 
using namespace std;

const int N = 3e4 + 10;
int read(){
	int x = 0, f = 1; char c = getchar();
	while(c < '0' || c > '9'){if(c == '-') f = -1; c = getchar();}
	while(c >= '0' && c <= '9'){x = (x << 1) + (x << 3) + (c ^ 48); c = getchar();}
	return x * f;
}

int n, q, u, v, dep[N], dfn[N], bel[N], head[N], w[N], cnt[N], id[N], tot, x, y, fat[N], a[N], rst;
char opt[10];
//string opt;

struct edge{
	int v, last;
}E[N * 2];

struct SeqmentTree{
	int l, r, sum, Maxn, tag, dat;
	#define l(x) t[x].l
	#define r(x) t[x].r
	#define sum(x) t[x].sum
	#define Maxn(x) t[x].Maxn
	#define tag(x) t[x].tag
	#define dat(x) t[x].dat
}t[N * 4];

void add(int u, int v){
	E[++tot].v = v;
	E[tot].last = head[u];
	head[u] = tot;
}

void dfs(int x, int fa){
	dep[x] = dep[fa] + 1;
	cnt[x]++; fat[x] = fa;
	for(int i = head[x]; i; i = E[i].last){
		int v = E[i].v;
		if(v == fa) continue;
		dfs(v, x);
		cnt[x] += cnt[v];
	}
}

void dfs_chain(int x, int chain){
	bel[x] = chain;
	dfn[x] = ++rst;//求dfs序
	id[rst] = x;
	int k = 0;
	for(int i = head[x]; i; i = E[i].last){
		int v = E[i].v;
		if((dep[v] == dep[x] + 1) && (cnt[v] > cnt[k])) k = v;
	}
	if(k != 0) dfs_chain(k, chain);
	for(int i = head[x]; i; i = E[i].last){
		int v = E[i].v;
		if((dep[v] == dep[x] + 1) && (v != k)) dfs_chain(v, v);
	}
}

void update(int p){
	sum(p) = sum(p << 1) + sum(p << 1 | 1);
	Maxn(p) = max(Maxn(p << 1), Maxn(p << 1 | 1));
}

void build(int p, int l, int r){
	l(p) = l, r(p) = r;
	if(l == r){
		sum(p) = Maxn(p) = a[id[l]];
		return ;
	}
	int mid = (l + r >> 1);
	build(p << 1, l, mid);
	build(p << 1 | 1, mid + 1, r);
	update(p);
}

void change(int p, int x, int y){
	if(l(p) == r(p)){
		sum(p) = Maxn(p) = y;
		return ;
	}
	int mid = (l(p) + r(p) >> 1);
	if(x <= mid) change(p << 1, x, y);
	else change(p << 1 | 1, x, y);
	update(p);
}

int lca(int x, int y){
	while(bel[x] != bel[y]){
		if(dep[bel[x]] > dep[bel[y]]){
			x = fat[bel[x]];
		}
		else y = fat[bel[y]];
	}
	return dep[x] > dep[y] ? y : x;
}

int ask_max(int p, int l, int r){
	if(l(p) >= l && r(p) <= r) return Maxn(p);
	int mid = (l(p) + r(p) >> 1);
	int res = -1e8;
	if(l <= mid) res = max(res, ask_max(p << 1, l, r));
	if(r > mid) res = max(res, ask_max(p << 1 | 1, l, r));
	return res;
}

int quary_max(int u, int v){//求 u -> v 路径上的最大值 
	int LCA = lca(u, v);//先求LCA
	int res = -1e8;
	while(bel[u] != bel[LCA]){//模板
		res = max(res, ask_max(1, dfn[bel[u]], dfn[u]));
		u = fat[bel[u]];
	}
	res = max(res, ask_max(1, dfn[LCA], dfn[u]));
	while(bel[v] != bel[LCA]){
		res = max(res, ask_max(1, dfn[bel[v]], dfn[v]));
		v = fat[bel[v]];
	}
	res = max(res, ask_max(1, dfn[LCA], dfn[v]));
	return res;
}

int ask_sum(int p, int l, int r){
	if(l(p) >= l && r(p) <= r) return sum(p);
	int mid = (l(p) + r(p) >> 1);
	int val = 0;
	if(l <= mid) val += ask_sum(p << 1, l, r);
	if(r > mid) val += ask_sum(p << 1 | 1, l, r);
	return val;
}

int quary_sum(int u, int v){
	int LCA = lca(u, v);
	int res = 0;
	while(bel[u] != bel[LCA]){
		res = res + ask_sum(1, dfn[bel[u]], dfn[u]);
		u = fat[bel[u]];
	}
	res += ask_sum(1, dfn[LCA], dfn[u]);
	while(bel[v] != bel[LCA]){
		res += ask_sum(1, dfn[bel[v]], dfn[v]);
		v = fat[bel[v]];
	}
	res += ask_sum(1, dfn[LCA], dfn[v]);
	res -= ask_sum(1, dfn[LCA], dfn[LCA]);//LCA这里被多算一次减去即可
	return res;
}

int main(){
	
	freopen("data.in", "r", stdin);
	freopen("data.out", "w", stdout);
	
	n = read();
	for(int i = 1; i < n; i++){
		u = read();
		v = read();
		add(u, v);
		add(v, u);
	}
	
	for(int i = 1; i <= n; i++) a[i] = read();
	
	
	dfs(1, 0);
	dfs_chain(1, 1);
	build(1, 1, rst);
	
	q = read();
	for(int i = 1; i <= q; i++){
		scanf("%s", opt + 1);
		if(opt[2] == 'H'){//修改 
			x = read(); y = read();
			change(1, dfn[x], y);
		}
		else if(opt[2] == 'M'){//求最大 
			x = read(); y = read();
			printf("%d\n", quary_max(x, y));
		}
		else{
			x = read(); y = read();
			printf("%d\n", quary_sum(x, y));
		}
	}
	
	return 0;
}

月下“毛景树”题面

分析:
树剖一般都是将的信息用线段树维护,因此我们考虑将 边权转化为点权。可以考虑将每一条边权都赋给 所连点深度较大的点。那么查询就是询问 从LCA断开的左右两条链,修改就是修改该边对应的节点即可。
CODE:

#include<bits/stdc++.h>
using namespace std;

const int N = 1e5 + 10;
int read(){
	int x = 0, f = 1; char c = getchar();
	while(c < '0' || c > '9'){if(c == '-') f = -1; c = getchar();}
	while(c >= '0' && c <= '9'){x = (x << 1) + (x << 3) + (c ^ 48); c = getchar();}
	return x * f;
}

int n, u[N], v[N], w[N], dfn[N], dep[N], fa[N], bel[N], siz[N], head[N], a[N], tot, rst, id[N], x, y, z, r[N];
char opt[10];
//string opt;
struct edge{
	int w, v, last;
}E[N * 2];

struct SeqmentTree{
	int l, r, tag_add, tag_cov, dat_add, dat_cov, Maxn;
	#define l(x) t[x].l
	#define r(x) t[x].r
	#define t_add(x) t[x].tag_add
	#define t_cov(x) t[x].tag_cov
	#define d_add(x) t[x].dat_add
	#define d_cov(x) t[x].dat_cov
	#define Maxn(x) t[x].Maxn
}t[N * 4];

void add(int u, int v, int w){
	E[++tot].v = v;
	E[tot].last = head[u];
	E[tot].w = w;
	head[u] = tot;
}

void dfs(int x, int fat){
	siz[x]++, dep[x] = dep[fat] + 1, fa[x] = fat;
	for(int i = head[x]; i; i = E[i].last){
		int v = E[i].v;
		if(v != fat){
			a[v] = E[i].w;//边权转化为点权 
			dfs(v, x);
			siz[x] += siz[v];
		}
	}
}

void dfs_chain(int x, int chain){
	bel[x] = chain; dfn[x] = ++rst, id[rst] = x;
	int k = 0;
	for(int i = head[x]; i; i = E[i].last){
		int v = E[i].v;
		if((dep[v] == dep[x] + 1) && (siz[v] > siz[k])) k = v;
	}
	if(k != 0) dfs_chain(k, chain);
	for(int i = head[x]; i; i = E[i].last){
		int v = E[i].v;
		if((dep[v] == dep[x] + 1) && (v != k)) dfs_chain(v, v);
	}
}

void update(int p){
	Maxn(p) = max(Maxn(p << 1), Maxn(p << 1 | 1));
}

void build(int p, int l, int r){
	l(p) = l, r(p) = r;
	if(l == r){
		Maxn(p) = a[id[l]];
		return ;
	}
	int mid = (l + r >> 1);
	build(p << 1, l, mid);
	build(p << 1 | 1, mid + 1, r);
	update(p);
}

int lca(int x, int y){
	while(bel[x] != bel[y]){
		if(dep[bel[x]] > dep[bel[y]]) x = fa[bel[x]];
		else y = fa[bel[y]];
	}
	return dep[x] > dep[y] ? y : x;
}

void spread(int p){
	if(t_cov(p)){
		t_cov(p << 1) = t_cov(p << 1 | 1) = 1;
		d_cov(p << 1) = d_cov(p << 1 | 1) = d_cov(p);
		Maxn(p << 1) = Maxn(p << 1 | 1) = d_cov(p);
		t_add(p << 1) = t_add(p << 1 | 1) = 0;
		d_add(p << 1) = d_add(p << 1 | 1) = 0;
		t_cov(p) = d_cov(p) = 0;
	}
	if(t_add(p)){
		d_add(p << 1) += d_add(p);
		d_add(p << 1 | 1) += d_add(p);
		t_add(p << 1) = 1;
		t_add(p << 1 | 1) = 1;
		Maxn(p << 1) += d_add(p);
		Maxn(p << 1 | 1) += d_add(p);
		d_add(p) = 0;
		t_add(p) = 0;
	}
}

int ask(int p, int l, int r){
	if(l(p) >= l && r(p) <= r) return Maxn(p);
	spread(p);
	int mid = (l(p) + r(p) >> 1);
	int res = 0;
	if(l <= mid) res = max(res, ask(p << 1, l, r));
	if(r > mid) res = max(res, ask(p << 1 | 1, l, r));
	return res;
}

int quary(int u, int v){
	int LCA = lca(u, v), res = 0;
	if(u != LCA){	
		while(bel[u] != bel[LCA]){
			res = max(res, ask(1, dfn[bel[u]], dfn[u]));
			u = fa[bel[u]];
		}
		if(u != LCA) res = max(res, ask(1, dfn[LCA] + 1, dfn[u]));
	}
	if(v != LCA){
		while(bel[v] != bel[LCA]){
			res = max(res, ask(1, dfn[bel[v]], dfn[v]));
			v = fa[bel[v]];
		}
		if(v != LCA) res = max(res, ask(1, dfn[LCA] + 1, dfn[v]));
	}
	return res;
}

void change(int p, int l, int r, int c, int T){
	if(T == 1){
		if(l(p) >= l && r(p) <= r){
			t_add(p) = 0; t_cov(p) = 1; d_add(p) = 0; d_cov(p) = c; Maxn(p) = c;
			return ;
		}
		spread(p);
		int mid = (l(p) + r(p) >> 1);
		if(l <= mid) change(p << 1, l, r, c, T);
		if(r > mid) change(p << 1 | 1, l, r, c, T);
		update(p);
	}
	else{
		if(l(p) >= l && r(p) <= r){
			t_add(p) = 1, d_add(p) += c, Maxn(p) = Maxn(p) + c;
			return ;
		}
		spread(p);
		int mid = (l(p) + r(p) >> 1);
		if(l <= mid) change(p << 1, l, r, c, T);
		if(r > mid) change(p << 1 | 1, l, r, c, T);
		update(p);
	}
}

void modify(int x, int y, int z, int T){
	int LCA = lca(x, y);
	if(x != LCA){
	   	while(bel[x] != bel[LCA]){
	   	    change(1, dfn[bel[x]], dfn[x], z, T);
			x = fa[bel[x]];	
		}
		if(x != LCA) change(1, dfn[LCA] + 1, dfn[x], z, T);
    }
	if(y != LCA){
	   	while(bel[y] != bel[LCA]){
	   	    change(1, dfn[bel[y]], dfn[y], z, T);
			y = fa[bel[y]];	
		}
		if(y != LCA) change(1, dfn[LCA] + 1, dfn[y], z, T);
	}
}

int main(){
	
	n = read();
	for(int i = 1; i < n; i++){
		u[i] = read(), v[i] = read(), w[i] = read();
		add(u[i], v[i], w[i]);
		add(v[i], u[i], w[i]);
	}
	
	dfs(1, 0);
	dfs_chain(1, 1);
	build(1, 1, rst);
	
	for(int i = 1; i < n; i++){
		int U = u[i], V = v[i];
		if(fa[U] == V) r[i] = U;
		else r[i] = V;
	}
	
	while(scanf("%s", opt + 1)){
		if(opt[2] == 't') break;
		
		if(opt[2] == 'a'){//求max 
            scanf("%d%d", &x, &y);
			printf("%d\n", quary(x, y));
		}
		else if(opt[2] == 'o'){//赋值 
            scanf("%d%d%d", &x, &y, &z);
			modify(x, y, z, 1); 
		}
		else if(opt[2] == 'h'){//第i条边 
			scanf("%d%d", &x, &y);
			change(1, dfn[r[x]], dfn[r[x]], y, 1);
		}
		else{//增加 
            scanf("%d%d%d", &x, &y, &z);
			modify(x, y, z, 2);
		}
	}
	
	return 0;
}

3. [SDOI2011] 染色题面

分析:
考虑用线段树用线段树维护区间内颜色段的数量,那么在子节点信息向上传递时有 s u m ( f a ) = s u m ( l s ) + s u m ( r s ) − ( r c o l o r ( l s ) = = l c o l o r ( r s ) ) sum(fa) = sum(ls) + sum(rs) - (rcolor(ls)==lcolor(rs)) sum(fa)=sum(ls)+sum(rs)(rcolor(ls)==lcolor(rs))。并且在我们查询两条相拼的链时,同样不可以仅仅相加,还应考虑若当前链的顶部与上方链的底部颜色相同时,答案数应减去1。在我们查询一段区间时也应遵循这一规则。
CODE:

#include<bits/stdc++.h>
using namespace std;

const int N = 1e5 + 10;
int read(){
	int x = 0, f = 1; char c = getchar();
	while(!isdigit(c)){if(c == '-') f = -1; c = getchar();}
	while(isdigit(c)){x = (x << 1) + (x << 3) + (c ^ 48); c = getchar();}
	return x * f;
}

int n, m, dfn[N], rst, bel[N], head[N], tot, u, v, color[N], id[N], siz[N], dep[N], x, y, z, fa[N];
struct edge{
	int v, last;
}E[N * 2];
struct SeqmentTree{
	int l, r, num, tag, col, lc, rc;//lc和rc分别记录左右端点的颜色 
	#define l(x) t[x].l
	#define r(x) t[x].r
	#define num(x) t[x].num
	#define tag(x) t[x].tag
	#define col(x) t[x].col
	#define lc(x) t[x].lc
	#define rc(x) t[x].rc
}t[N * 4];
char opt;

void add(int u, int v){
	E[++tot].v = v;
	E[tot].last = head[u];
	head[u] = tot;
}

void dfs(int x, int fat){
	dep[x] = dep[fat] + 1; fa[x] = fat; siz[x]++;
	for(int i = head[x]; i; i = E[i].last){
		int v = E[i].v;
		if(v == fat) continue;
		dfs(v, x);
		siz[x] += siz[v];
	}
}

void dfs_chain(int x, int chain){
	bel[x] = chain; dfn[x] = ++rst; id[rst] = x;
	int k = 0;
	for(int i = head[x]; i; i = E[i].last){
		int v = E[i].v;
		if((dep[v] == dep[x] + 1) && (siz[v] > siz[k])) k = v;
	}
	if(k) dfs_chain(k, chain);
	for(int i = head[x]; i; i = E[i].last){
		int v = E[i].v;
		if((dep[v] == dep[x] + 1) && (v != k)) dfs_chain(v, v);
	}
}

void build(int p, int l, int r){
	l(p) = l, r(p) = r;
	if(l == r){
		lc(p) = rc(p) = col(p) = color[id[l]];
		num(p) = 1;
		return ;
	}
	int mid = (l + r >> 1);
	build(p << 1, l, mid);
	build(p << 1 | 1, mid + 1, r);
	num(p) = num(p << 1) + num(p << 1 | 1);
	if(color[id[mid]] == color[id[mid + 1]]) num(p)--;
	lc(p) = lc(p << 1); rc(p) = rc(p << 1 | 1);
}

int lca(int x, int y){
	while(bel[x] != bel[y]){
		if(dep[bel[x]] > dep[bel[y]]) x = fa[bel[x]];
		else y = fa[bel[y]];
	}
	return dep[x] > dep[y] ? y : x;
}

void spread(int p){
	if(tag(p)){
		lc(p << 1) = rc(p << 1) = lc(p << 1 | 1) = rc(p << 1 | 1) = col(p << 1) = col(p << 1 | 1) = col(p);
		tag(p << 1) = tag(p << 1 | 1) = 1;
		num(p << 1) = num(p << 1 | 1) = 1;
		tag(p) = 0; col(p) = 0;
	}
}

int ask_col(int p, int x){
	if(l(p) == r(p)) return col(p);
	spread(p);
	int mid = (l(p) + r(p) >> 1);
	if(x <= mid) return ask_col(p << 1, x);
	if(x > mid) return ask_col(p << 1 | 1, x);
}

int ask_num(int p, int l, int r){
	if(l(p) >= l && r(p) <= r) return num(p);
	spread(p);
	int res = 0; int mid = (l(p) + r(p) >> 1);
	if(l <= mid) res += ask_num(p << 1, l, r);
	if(r > mid) res += ask_num(p << 1 | 1, l, r);
	if((l <= mid && r > mid) && ask_col(1, mid) == ask_col(1, mid + 1)) res--;
	return res;
}

void update(int p){
	num(p) = num(p << 1) + num(p << 1 | 1);
	int mid = (l(p) + r(p) >> 1);
	if(rc(p << 1) == lc(p << 1 | 1)) num(p)--;
	lc(p) = lc(p << 1); rc(p) = rc(p << 1 | 1);
}

void change(int p, int l, int r, int c){
	if(l(p) >= l && r(p) <= r){lc(p) = rc(p) = c; num(p) = 1; col(p) = c; tag(p) = 1; return ;}
	spread(p);
	int mid = (l(p) + r(p) >> 1);
	if(l <= mid) change(p << 1, l, r, c);
	if(r > mid) change(p << 1 | 1, l, r, c);
	update(p);
}

void modify(int u, int v, int c){
	int LCA = lca(u, v);
	while(bel[u] != bel[LCA]){
		change(1, dfn[bel[u]], dfn[u], c);
		u = fa[bel[u]];
	}
	change(1, dfn[LCA], dfn[u], c);
	while(bel[v] != bel[LCA]){
		change(1, dfn[bel[v]], dfn[v], c);
		v = fa[bel[v]];
	}
	change(1, dfn[LCA], dfn[v], c);
}

int quary(int u, int v){
	int LCA = lca(u, v), res = 0;
	while(bel[u] != bel[LCA]){
		res += ask_num(1, dfn[bel[u]], dfn[u]);//查询区间
		if(ask_col(1, dfn[fa[bel[u]]]) == ask_col(1, dfn[bel[u]])) res--;//减掉重复贡献
		u = fa[bel[u]];
	}
	res += ask_num(1, dfn[LCA], dfn[u]);
	while(bel[v] != bel[LCA]){
		res += ask_num(1, dfn[bel[v]], dfn[v]);
		if(ask_col(1, dfn[fa[bel[v]]]) == ask_col(1, dfn[bel[v]])) res--;
		v = fa[bel[v]];
	}
	res += ask_num(1, dfn[LCA], dfn[v]);
	res--;//都查询到了LCA, 相当于收尾相同。 最后 LCA 的贡献加了2次
	return res;
}

int main(){
	freopen("color.in", "r", stdin);
	freopen("color.out", "w", stdout);
	n = read(); m = read();
	for(int i = 1; i <= n; i++)
		color[i] = read();
	for(int i = 1; i < n; i++){
		u = read(), v = read();
		add(u, v);
		add(v, u);
	}
	
	dfs(1, 0);
	dfs_chain(1, 1);
	build(1, 1, rst);
	
	for(int i = 1; i <= m; i++){
		scanf("\n%c", &opt);
		if(opt == 'Q'){
			x = read(), y = read();
			printf("%d\n", quary(x, y));
		}
		else{
			x = read(), y = read(), z = read();
			modify(x, y, z);
		}
	}
	
	return 0;
}
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值