修改时提取路径打tag就可以
询问时提取路径直接返回splay根节点的值
值得一提的是Pushup, 我们发现如果x 与 x的前驱颜色相同, 那么就算一条路径, 后继同理
为了高效地找到前驱后继的颜色, 我们维护一个lcol , rcol 表示前驱后继的颜色
t[x].lcol = t[ls].lcol
t[x].rcol = t[rs].rcol
如果 t[x].col = t[ls].rcol 说明x与前驱颜色相同
如果 t[x].col = t[rs].lcol 说明x与后继颜色相同
#include<bits/stdc++.h>
#define N 100050
#define ls t[x].ch[0]
#define rs t[x].ch[1]
using namespace std;
struct Node{
int ch[2],fa,val,col,lcol,rcol,tag,rev;
}t[N];
int n,m;
int read(){
int cnt=0; char ch=0;
while(!isdigit(ch))ch=getchar();
while(isdigit(ch))cnt=cnt*10+(ch-'0'),ch=getchar();
return cnt;
}
bool isRoot(int x){
int fa = t[x].fa; if(!fa) return true;
return t[fa].ch[0]!=x && t[fa].ch[1]!=x;
}
void Pushcol(int x,int c){ t[x].col = t[x].lcol = t[x].rcol = t[x].tag = c; t[x].val = 1;}
void Pushrev(int x){t[x].rev ^= 1; swap(ls,rs); swap(t[x].lcol, t[x].rcol);}
void Pushdown(int x){
if(t[x].tag){
if(ls) Pushcol(ls,t[x].tag);
if(rs) Pushcol(rs,t[x].tag);
t[x].tag = 0;
}
if(t[x].rev){
if(ls) Pushrev(ls);
if(rs) Pushrev(rs);
t[x].rev = 0;
}
}
void Pushpath(int x){
if(!isRoot(x)) Pushpath(t[x].fa);
Pushdown(x);
}
void Pushup(int x){
t[x].lcol = ls ? t[ls].lcol : t[x].col;
t[x].rcol = rs ? t[rs].rcol : t[x].col;
t[x].val = 1;
if(ls) t[x].val += t[ls].val - (t[ls].rcol == t[x].col);
if(rs) t[x].val += t[rs].val - (t[rs].lcol == t[x].col);
}
void rotate(int x){
int y = t[x].fa, z = t[y].fa;
int k = t[y].ch[1] == x;
if(!isRoot(y)) t[z].ch[t[z].ch[1]==y] = x;
t[x].fa = z;
t[y].ch[k] = t[x].ch[k^1];
t[t[x].ch[k^1]].fa = y;
t[x].ch[k^1] = y; t[y].fa = x;
Pushup(y); Pushup(x);
}
void Splay(int x){
Pushpath(x);
while(!isRoot(x)){
int y = t[x].fa, z = t[y].fa;
if(!isRoot(y))
(t[y].ch[0]==x) ^ (t[z].ch[0]==x) ? rotate(x) : rotate(y);
rotate(x);
} Pushup(x);
}
void Access(int x){
for(int y=0;x;y=x,x=t[x].fa)
Splay(x), rs = y, Pushup(x);
}
void Makeroot(int x){ Access(x); Splay(x); Pushrev(x);}
void Link(int x,int y){Makeroot(x); t[x].fa = y;}
int main(){
n = read(), m = read();
for(int i=1;i<=n;i++) t[i].col = read();
for(int i=1;i<n;i++){
int x = read(), y = read();
Link(x,y);
}
while(m--){
char s[3]; scanf("%s",s);
if(s[0] == 'Q'){
int x = read(), y = read();
Makeroot(x); Access(y); Splay(y);
printf("%d\n",t[y].val);
}
if(s[0] == 'C'){
int x = read(), y = read(), c = read();
Makeroot(x); Access(y); Splay(y);
Pushcol(y,c);
}
} return 0;
}