【题目链接】
【算法】
树链剖分
【代码】
本题中,笔者求最近公共祖先(LCA)时没有使用树链剖分“沿重链往上跳”的方式,而是使用倍增法。笔者认为这样比较好写,代码可读性也比较高。
此外,笔者的线段树没有使用懒惰标记:只要当前访问节点所对应区间的颜色段数为 1(即整个区间同色),就先把该颜色下传给两个子节点,再继续向下访问。
#include<bits/stdc++.h>
using namespace std;
// Binary-lifting levels: 2^17 > 1e5, so 18 levels cover any depth below MAXN.
const int MAXLOG = 18;
const int MAXN = 1e5 + 10;

// Scratch variables used by main() (and read/written there only).
int i, n, m, timer, x, y, c, t;

// Per-vertex data, indexed by vertex id (1-based) unless noted otherwise.
int dep[MAXN];            // depth; the root (vertex 1) stays at 0
int fa[MAXN];             // parent; 0 for the root
int size[MAXN];           // subtree size
int son[MAXN];            // heavy child; 0 when the vertex has no children
int dfn[MAXN];            // position of the vertex in heavy-first DFS order
int top[MAXN];            // head vertex of the heavy chain containing it
int val[MAXN];            // initial color
int pos[MAXN];            // indexed by DFS position: inverse of dfn
int anc[MAXN][MAXLOG];    // anc[v][k]: 2^k-th ancestor of v, 0 if it does not exist
std::vector<int> e[MAXN]; // adjacency lists of the undirected tree
char opt[10];             // operation token read by main()
// Segment tree over heavy-first DFS positions. Each node keeps the number of
// maximal same-color runs in its interval plus the colors at both ends.
// There is no explicit lazy tag: a node whose run count is 1 is uniformly
// colored, so its color is copied down to both children before descending.
struct SegmentTree {
    struct Node {
        int l, r;    // interval [l, r] covered by this node
        int sum;     // number of maximal same-color runs inside [l, r]
        int lcover;  // color at position l
        int rcover;  // color at position r
    } Tree[MAXN * 4];

    // Rebuild a node from its children; equal colors meeting at the split
    // point belong to one run, so they are counted once.
    inline void push_up(int index) {
        Node &cur = Tree[index];
        const Node &lhs = Tree[2 * index];
        const Node &rhs = Tree[2 * index + 1];
        cur.lcover = lhs.lcover;
        cur.rcover = rhs.rcover;
        cur.sum = lhs.sum + rhs.sum - (lhs.rcover == rhs.lcover ? 1 : 0);
    }

    // Copy a uniformly colored node's color into both children.
    // Only called when Tree[index].sum == 1.
    inline void push_down(int index) {
        Node &lhs = Tree[2 * index];
        Node &rhs = Tree[2 * index + 1];
        lhs.sum = 1;
        rhs.sum = 1;
        lhs.lcover = lhs.rcover = Tree[index].lcover;
        rhs.lcover = rhs.rcover = Tree[index].rcover;
    }

    // Build the subtree for [l, r]; leaf colors come from val[pos[l]], i.e.
    // the initial color of the vertex sitting at DFS position l.
    inline void build(int index, int l, int r) {
        Node &cur = Tree[index];
        cur.l = l;
        cur.r = r;
        if (l != r) {
            const int mid = (l + r) >> 1;
            build(2 * index, l, mid);
            build(2 * index + 1, mid + 1, r);
            push_up(index);
            return;
        }
        cur.lcover = cur.rcover = val[pos[l]];
        cur.sum = 1;
    }

    // Paint positions [l, r] (which must lie inside this node) with `color`.
    inline void modify(int index, int l, int r, int color) {
        Node &cur = Tree[index];
        if (cur.l == l && cur.r == r) {
            cur.lcover = cur.rcover = color;
            cur.sum = 1;
            return;
        }
        if (cur.sum == 1) push_down(index);
        const int mid = (cur.l + cur.r) >> 1;
        if (r <= mid) {
            modify(2 * index, l, r, color);
        } else if (l > mid) {
            modify(2 * index + 1, l, r, color);
        } else {
            modify(2 * index, l, mid, color);
            modify(2 * index + 1, mid + 1, r, color);
        }
        push_up(index);
    }

    // Number of color runs within positions [l, r] (inside this node).
    inline int query(int index, int l, int r) {
        const Node &cur = Tree[index];
        if (cur.l == l && cur.r == r) return cur.sum;
        if (cur.sum == 1) push_down(index);
        const int mid = (cur.l + cur.r) >> 1;
        if (r <= mid) return query(2 * index, l, r);
        if (l > mid) return query(2 * index + 1, l, r);
        // The range straddles mid: runs touching mid and mid+1 merge if equal.
        const int joined = (Tree[2 * index].rcover == Tree[2 * index + 1].lcover) ? 1 : 0;
        return query(2 * index, l, mid) + query(2 * index + 1, mid + 1, r) - joined;
    }

    // Color at a single position, pushing uniform colors down along the way.
    inline int get(int index, int at) {
        while (Tree[index].l != Tree[index].r) {
            if (Tree[index].sum == 1) push_down(index);
            const int mid = (Tree[index].l + Tree[index].r) >> 1;
            index = (at <= mid) ? 2 * index : 2 * index + 1;
        }
        return Tree[index].lcover;
    }
} T;
inline void dfs1(int x) {
int i,y;
anc[x][0] = fa[x];
for (i = 1; i < MAXLOG; i++) {
if (dep[x] < (1 << i)) break;
anc[x][i] = anc[anc[x][i-1]][i-1];
}
size[x] = 1;
for (i = 0; i < e[x].size(); i++) {
y = e[x][i];
if (fa[x] != y) {
dep[y] = dep[x] + 1;
fa[y] = x;
dfs1(y);
size[x] += size[y];
if (size[y] > size[son[x]]) son[x] = y;
}
}
}
// Second DFS pass of heavy-light decomposition over x's subtree.
// It numbers vertices with a global counter (timer), visiting the heavy
// child first so each heavy chain occupies consecutive numbers. It records
// dfn[] (vertex -> position), pos[] (position -> vertex) and the chain head
// top[], with tp being the head for x.
inline void dfs2(int x, int tp) {
    top[x] = tp;
    timer++;
    dfn[x] = timer;
    pos[timer] = x;
    if (son[x] != 0) dfs2(son[x], tp);  // heavy child continues the chain
    for (int child : e[x]) {
        if (child == fa[x] || child == son[x]) continue;
        dfs2(child, child);  // each light child heads a new chain
    }
}
// Lowest common ancestor of x and y, using the binary-lifting table anc[][].
inline int lca(int x, int y) {
    if (dep[x] > dep[y]) std::swap(x, y);
    // Lift the deeper vertex y by the depth difference, one set bit at a time.
    const int diff = dep[y] - dep[x];
    for (int k = 0; k < MAXLOG; k++) {
        if ((diff >> k) & 1) y = anc[y][k];
    }
    if (x == y) return x;
    // Lift both while their 2^k-th ancestors differ. They end up as the two
    // children of the LCA.
    for (int k = MAXLOG - 1; k >= 0; k--) {
        if (anc[x][k] == anc[y][k]) continue;
        x = anc[x][k];
        y = anc[y][k];
    }
    return anc[x][0];
}
// Paint every vertex on the path from x up to y with color c.
// y must be an ancestor of x (or x itself); main() passes the LCA.
inline void modify(int x, int y, int c) {
    while (top[x] != top[y]) {
        const int head = top[x];
        // Repaint the chain segment head..x, then hop to the parent chain.
        T.modify(1, dfn[head], dfn[x], c);
        x = fa[head];
    }
    // x and y now share a chain; y is the shallower end.
    T.modify(1, dfn[y], dfn[x], c);
}
// Number of color segments on the path from x up to y.
// y must be an ancestor of x (or x itself); main() passes the LCA.
inline int query(int x, int y) {
    int segments = 0;
    while (top[x] != top[y]) {
        const int head = top[x];
        segments += T.query(1, dfn[head], dfn[x]);
        // The chain head and its parent are adjacent on the path. If they
        // share a color, the two runs are really one.
        if (T.get(1, dfn[head]) == T.get(1, dfn[fa[head]])) segments--;
        x = fa[head];
    }
    return segments + T.query(1, dfn[y], dfn[x]);
}
// Reads n vertex colors and n-1 edges, then builds the heavy-light
// decomposition and the segment tree. It then processes m operations:
//   "C a b c": paint every vertex on path a..b with color c;
//   "Q a b"  : print the number of color segments on path a..b.
int main() {
    // Without a valid header, timer would stay 0 and T.build(1, 1, 0) would
    // recurse forever, so stop early.
    if (scanf("%d%d", &n, &m) != 2) return 0;
    for (i = 1; i <= n; i++) scanf("%d", &val[i]);
    for (i = 1; i < n; i++) {
        scanf("%d%d", &x, &y);
        e[x].push_back(y);
        e[y].push_back(x);
    }
    dfs1(1);
    dfs2(1, 1);
    T.build(1, 1, timer);
    while (m--) {
        // The width limit keeps an over-long token from overflowing opt[10].
        // A failed read (EOF) ends processing instead of reusing stale data.
        if (scanf("%9s", opt) != 1) break;
        if (opt[0] == 'C') {
            scanf("%d%d%d", &x, &y, &c);
            t = lca(x, y);
            modify(x, t, c);
            modify(y, t, c);
        } else {
            scanf("%d%d", &x, &y);
            t = lca(x, y);
            // Both half-paths include t, so the segment containing t is
            // counted twice; subtract one.
            printf("%d\n", query(x, t) + query(y, t) - 1);
        }
    }
    return 0;
}