感觉非常的奇怪。。。。
以前rev的写法是
inline void rev(int u) { if (u) std::swap(lc , rc) , std::swap(lcol[u] , rcol[u]) , rv[lc] ^= 1 , rv[rc] ^= 1 , rv[u] = 0 ; }
就是说这个点的rv表示是否需要rev
以前这么写都没问题。。。。。
这次就爆炸了QAQ
改成了这个点的rv表示是否已经rev过了
这样
inline void rev(int u) { if (u) std::swap(lc , rc) , std::swap(lcol[u] , rcol[u]) , rv[u] ^= 1 ; }
然后就A掉了。。。。。感觉十分奇怪。。。。。
至于维护信息都是十分容易的东西
注意rev和set col的时候都要把lcol和rcol搞一搞
#include <bits/stdc++.h>
#define For(i,a,b) for(int i=a;i<b;i++)
#define lc ch[u][0]
#define rc ch[u][1]
#define maxn 100003
typedef int arr[maxn];
arr fa , rv , cv , sum , lcol , rcol , col , sta;
int ch[maxn][2] , n , m , top;
inline bool isrt(int u) { return (ch[fa[u]][0] != u) && (ch[fa[u]][1] != u) ; }
inline void rev(int u) { if (u) std::swap(lc , rc) , std::swap(lcol[u] , rcol[u]) , rv[u] ^= 1 ; }
inline void Col(int u , int v) { if (u) col[u] = cv[u] = lcol[u] = rcol[u] = v , sum[u] = 1 ; }
inline void ps(int u) {
if (!u) return ;
if (rv[u]) {
rev(lc) , rev(rc);
rv[u] = 0;
}
if (cv[u] != -1) {
Col(lc , cv[u]) , Col(rc , cv[u]);
cv[u] = -1;
}
}
inline void mt(int u) {
sum[u] = 1 , lcol[u] = rcol[u] = col[u];
if (lc) sum[u] += sum[lc] - (rcol[lc] == col[u]) , lcol[u] = lcol[lc];
if (rc) sum[u] += sum[rc] - (lcol[rc] == col[u]) , rcol[u] = rcol[rc];
}
inline void rot(int u) {
int f = fa[u] , g = fa[f] , l , r;
l = (ch[f][1] == u) , r = l ^ 1;
if (!isrt(f)) ch[g][ch[g][1] == f] = u;
fa[u] = g , fa[f] = u;if (ch[u][r]) fa[ch[u][r]] = f;
ch[f][l] = ch[u][r] , ch[u][r] = f;
mt(f) , mt(u);
}
inline void clear(int u) {
for(sta[top ++] = u;!isrt(u);u = fa[u]) sta[top ++] = fa[u];
for(;top;) ps(sta[-- top]);
}
inline void splay(int u) {
for(clear(u);!isrt(u);rot(u)) {
int f = fa[u] , g = fa[f];
if (!isrt(f)) rot(((ch[f][1] == u) ^ (ch[g][1] == f)) ? u : f);
}
mt(u);
}
inline void access(int u) {
int v = u;
for(int t = 0;u;t = u , u = fa[u])
splay(u) , rc = t ;
splay(v);
}
inline void mkrt(int u) {
access(u) ;
rev(u);
}
inline void split(int u , int v) {
mkrt(u) ;
access(v) ;
}
inline void link(int u , int v) {
mkrt(u) , fa[u] = v;
}
void Set(int u , int v , int c) {
split(u , v);
Col(v , c);
}
void Ask(int u , int v) {
split(u , v);
printf("%d\n" , sum[v]);
}
inline void input() {
scanf("%d%d" , &n , &m);
For(i , 1 , n + 1) scanf("%d" , col + i) , lcol[i] = rcol[i] = col[i] , cv[i] = -1 , sum[i] = 1;
For(i , 1 , n) {
int u , v;
scanf("%d%d" , &u , &v);
link(u , v);
}
}
inline void solve() {
For(i , 0 , m) {
char cmd[2];
int u , v , c;
scanf("%s%d%d" , cmd , &u , &v);
if (cmd[0] == 'C')
scanf("%d" , &c) , Set(u , v , c);
else if (cmd[0] == 'Q')
Ask(u , v);
}
}
int main() {
input();
solve();
return 0;
}