题目链接:https://www.lydsy.com/JudgeOnline/problem.php?id=2243
解题心得:
- 整体思路:先对树做重链剖分,使每条重链上的节点编号连续;询问和修改时沿重链一段一段地向上跳,每一段对应线段树上的一个连续区间,用线段树维护。
- 线段树的每个节点记录区间内的颜色段数,以及区间最左端、最右端的颜色。合并左右儿子时,如果左儿子最右端的颜色与右儿子最左端的颜色相同,总段数要减一。同理,在重链上跳的时候,从一条链的链顶跳到链顶的父亲时,要判断这两个点的颜色是否相同,相同则段数减一。
- 最后是我自己写出的一个bug:跳链找lca时,应该让所在链的链顶深度更大的那个点先跳,而不是比较当前这两个点本身的深度。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+100;
// num[v]:   initial color of original tree node v (1-indexed, filled by init()).
// color[i]: initial color of the node that dfs() assigned position i.
// Index:    number of positions handed out by dfs(); also the segment tree's range size.
int num[maxn], n, m, Index, color[maxn];

// One entry per node of the original tree.
struct Node {
    int deep;   // depth of the node (filled by get_deep_cnt)
    int num;    // position of the node after heavy-light decomposition
    int fa;     // parent node
    int gr_fa;  // top node of the heavy chain that contains this node
    int size;   // number of nodes in this node's subtree
} node[maxn];

// One entry per segment-tree node over the heavy-light positions.
struct Node2 {
    int lc;     // color at the leftmost position of the node's range
    int rc;     // color at the rightmost position of the node's range
    int cnt;    // number of maximal same-color segments in the range
    int lazy;   // pending "paint the whole range" color (see pushdown for the empty value)
} node2[maxn << 2];  // root 1 with children 2r / 2r+1 needs up to 4x the leaf count

std::vector<int> ve[maxn];  // adjacency lists of the (undirected) tree
void init() {
scanf("%d%d",&n, &m);
for(int i=1;i<=n;i++) scanf("%d", &num[i]);
for(int i=1;i<n;i++) {
int a, b; scanf("%d%d", &a, &b);
ve[a].push_back(b);
ve[b].push_back(a);
}
}
int get_deep_cnt(int pre, int now, int dep) {
node[now].size = 1;
node[now].deep = dep;
node[now].fa = pre;
for(int i=0;i<ve[now].size();i++) {
int v = ve[now][i];
if(v == pre) continue;
node[now].size += get_deep_cnt(now, v, dep+1);
}
return node[now].size;
}
void dfs(int pre, int now, int w) {
if(w == 0) node[now].gr_fa = now;
else node[now].gr_fa = node[pre].gr_fa;
node[now].num = ++Index;
color[Index] = num[now];
int Max = 0, pos = -1;
for(int i=0;i<ve[now].size();i++) {
int v = ve[now][i];
if(v == pre) continue;
if(node[v].size > Max) {
Max = node[v].size;
pos = v;
}
}
if(pos == -1) return ;
dfs(now, pos, 1);
for(int i=0;i<ve[now].size();i++) {
int v = ve[now][i];
if(v == pre || v == pos) continue;
dfs(now, v, 0);
}
}
void update(int root) {
int chl = root<<1, chr = root<<1|1;
node2[root].cnt = node2[chl].cnt + node2[chr].cnt;
node2[root].lc = node2[chl].lc;
node2[root].rc = node2[chr].rc;
if(node2[chl].rc == node2[chr].lc) node2[root].cnt--;
}
// Value of Node2::lazy meaning "no pending paint on this node".
// 0 cannot be used here: colors are read with %d, so 0 is an ordinary color,
// and a 0 sentinel made a paint-with-color-0 on an internal node silently
// skip its children.
// NOTE(review): assumes no input color equals -1 — confirm against the problem's constraints.
const int NO_LAZY = -1;

// Builds segment-tree node `root` covering positions [l, r]. Leaves take their
// color from color[]; internal nodes are merged with update().
void build_tree(int root, int l, int r) {
    node2[root].lazy = NO_LAZY;
    if (l == r) {
        node2[root].cnt = 1;
        node2[root].lc = node2[root].rc = color[l];
        return;
    }
    int mid = (l + r) >> 1;
    build_tree(root << 1, l, mid);
    build_tree(root << 1 | 1, mid + 1, r);
    update(root);  // sets cnt / lc / rc from the children
}

// Paints the whole range of node `root` with color c. The paint is also kept
// as the node's pending tag so it can be pushed further down later.
static void apply_color(int root, int c) {
    node2[root].lc = node2[root].rc = c;
    node2[root].cnt = 1;
    node2[root].lazy = c;
}

// Pushes the pending paint of internal node `root`, if any, to both children
// and clears it.
void pushdown(int root) {
    const int c = node2[root].lazy;
    if (c == NO_LAZY)
        return;
    apply_color(root << 1, c);
    apply_color(root << 1 | 1, c);
    node2[root].lazy = NO_LAZY;
}
void change(int root, int l,int r, int ql, int qr, int c) {
if(l == ql && r == qr) {
node2[root].lc = node2[root].rc = c;
node2[root].cnt = 1;
node2[root].lazy = c;
return ;
}
pushdown(root);
int mid = l + r >>1;
int chl = root<<1, chr = root<<1|1;
if(qr <= mid) change(chl, l, mid, ql, qr, c);
else if(ql > mid) change(chr, mid+1, r, ql, qr, c);
else {
change(chl, l, mid, ql, mid, c);
change(chr, mid+1, r, mid+1, qr, c);
}
update(root);
}
// Paints every node on the tree path p1..p2 with color c. The endpoint whose
// chain top is deeper repeatedly paints its chain segment and jumps above it;
// once both endpoints share a chain, the remaining contiguous range is painted.
void change(int p1, int p2, int c) {
    for (;;) {
        int top1 = node[p1].gr_fa;
        int top2 = node[p2].gr_fa;
        if (top1 == top2)
            break;
        if (node[top1].deep < node[top2].deep) {
            std::swap(p1, p2);
            std::swap(top1, top2);
        }
        change(1, 1, Index, node[top1].num, node[p1].num, c);
        p1 = node[top1].fa;
    }
    int lo = node[p1].num;
    int hi = node[p2].num;
    if (lo > hi)
        std::swap(lo, hi);
    change(1, 1, Index, lo, hi, c);
}
// Returns the color at position `pos`, walking down from node `root`
// (covering [l, r]) and pushing pending paint tags along the way.
int find_va(int root, int l, int r, int pos) {
    while (l != r) {
        pushdown(root);
        const int mid = (l + r) >> 1;
        if (pos <= mid) {
            root = root << 1;
            r = mid;
        } else {
            root = root << 1 | 1;
            l = mid + 1;
        }
    }
    return node2[root].lc;
}
// Returns the number of maximal same-color segments among positions
// [ql, qr]. The query runs on segment-tree node `root`, which covers [l, r];
// [ql, qr] must lie within [l, r].
int query(int root, int l, int r, int ql, int qr) {
    if (l == ql && r == qr)
        return node2[root].cnt;
    pushdown(root);
    int mid = (l + r) >> 1;
    int chl = root << 1, chr = root << 1 | 1;
    if (qr <= mid)
        return query(chl, l, mid, ql, qr);
    if (ql > mid)
        return query(chr, mid + 1, r, ql, qr);
    int res = query(chl, l, mid, ql, mid) + query(chr, mid + 1, r, mid + 1, qr);
    // After pushdown(root) both children's summaries are current, so the
    // colors at positions mid and mid+1 are simply node2[chl].rc and
    // node2[chr].lc. There is no need to re-descend from the root (find_va)
    // twice per split. If they match, the two halves' boundary segments are
    // one segment.
    if (node2[chl].rc == node2[chr].lc)
        --res;
    return res;
}
// Counts the color segments on the tree path p1..p2. Each chain segment
// climbed contributes its own count. When jumping from a chain top to that
// top's parent, one is subtracted if both have the same color, because the
// segments on the two sides of the jump are really one.
int query(int p1, int p2) {
    int segments = 0;
    for (;;) {
        int top1 = node[p1].gr_fa;
        int top2 = node[p2].gr_fa;
        if (top1 == top2)
            break;
        if (node[top1].deep < node[top2].deep) {
            std::swap(p1, p2);
            std::swap(top1, top2);
        }
        segments += query(1, 1, Index, node[top1].num, node[p1].num);
        const int parent = node[top1].fa;
        if (find_va(1, 1, Index, node[top1].num) == find_va(1, 1, Index, node[parent].num))
            --segments;
        p1 = parent;
    }
    int lo = node[p1].num;
    int hi = node[p2].num;
    if (lo > hi)
        std::swap(lo, hi);
    return segments + query(1, 1, Index, lo, hi);
}
// Reads the tree, runs the heavy-light decomposition, builds the segment tree,
// then processes m operations:
//   "Q a b"                        -> print the number of color segments on path a..b
//   any other op (presumably "C")  -> read "a b c" and paint path a..b with color c
int main() {
    // freopen("1.in.txt", "r", stdin);
    init();
    get_deep_cnt(1, 1, 1);  // root at node 1, depth 1; node 1's fa is recorded as 1 itself
    dfs(-1, 1, 0);          // -1: root has no parent; 0: root starts a new heavy chain
    build_tree(1, 1, Index);
    while (m--) {
        char ope[5];
        // Width-limited read: plain %s could write past the 5-byte buffer.
        // Stop on EOF / malformed input instead of acting on an uninitialized buffer.
        if (scanf("%4s", ope) != 1)
            break;
        if (ope[0] == 'Q') {
            int p1, p2;
            scanf("%d%d", &p1, &p2);
            printf("%d\n", query(p1, p2));
        } else {
            int p1, p2, c;
            scanf("%d%d%d", &p1, &p2, &c);
            change(p1, p2, c);
        }
    }
    return 0;
}