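// Problem: on a tree, support (a) recoloring every vertex on the path u-v, and
// (b) querying how many maximal same-color segments the path u-v splits into.
// Approach:
// Heavy-light decomposition with a segment tree on top. This was my first time writing
// that combination; the code is messy and I made quite a few mistakes along the way,
// so I added some comments as reminders not to repeat them.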
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Segment-tree shorthands. Every function that uses them names the current node `o`,
// which covers the position range [L, R].
#define m ((L+R)>>1)   // midpoint of [L, R]
#define lc (o<<1)      // left child of node o
#define rc (o<<1|1)    // right child of node o
#define ls lc,L,m      // (node, L, R) argument triple for the left child
#define rs rc,m+1,R    // (node, L, R) argument triple for the right child
#define init 1,1,n     // (node, L, R) argument triple for the root covering [1, n]
const int N = 100005;
// Command word read with scanf("%s"); only op[0] is inspected. It must hold at least one
// character plus the terminating '\0' (a 1-byte buffer overflowed on every read).
char op[16];
// n = vertex count, q = query count, e = last edge id, x/y/z = current command operands,
// tt = last HLD position handed out.
int n, q, e, x, y, z, tt;
// Segment tree over HLD positions: c = number of color segments in the node's range,
// l / r = color at its leftmost / rightmost position, st = pending "whole range is this
// color" tag, where -1 means no tag.
int c[N*4], l[N*4], r[N*4], st[N*4];
// v = color of each vertex as read from input; adjacency list hd/nxt/to, where each
// undirected edge is stored as two directed ones.
int v[N], hd[N], nxt[N*2], to[N*2];
// HLD data: d = depth (root is 1), p = position of a vertex, tp = top of its chain,
// f = parent, s = heavy child (0 if none), sz = subtree size, pp = inverse of p.
int d[N], p[N], tp[N], f[N], s[N], sz[N], pp[N];
// Append the directed edge x -> y to x's adjacency list. Edge ids start at 1, so a
// next-pointer of 0 marks the end of a list.
void add(int x, int y) {
    ++e;
    to[e] = y;
    nxt[e] = hd[x];
    hd[x] = e;
}
// Recompute node o's summary (c/l/r) for its range [L, R].
// Both conditions matter: the pending tag st[o], and whether the node has children (L < R).
// A tagged node is a single color. Otherwise an internal node merges its two children,
// counting one segment fewer when the colors meeting at the midpoint are equal.
void pu(int o, int L, int R) {
    if (st[o] != -1) {
        c[o] = 1;
        l[o] = r[o] = st[o];
        return;
    }
    if (L < R) {
        const bool joined = (r[lc] == l[rc]);
        c[o] = c[lc] + c[rc] - (joined ? 1 : 0);
        l[o] = l[lc];
        r[o] = r[rc];
    }
}
// Move node o's pending color tag (if any) down to both children and clear it.
// Only the tags move; the children's c/l/r are not touched here.
void pd(int o) {
    const int tag = st[o];
    if (tag == -1) return;
    st[lc] = tag;
    st[rc] = tag;
    st[o] = -1;
}
// Build the segment tree for positions [L, R] rooted at node o.
// Leaf position L holds the color of vertex pp[L].
void bd(int o, int L, int R) {
    if (L != R) {
        bd(ls);
        bd(rs);
        pu(o, L, R);
        return;
    }
    // Map the position back to its vertex through pp; v[L] or v[p[L]] would be wrong.
    const int color = v[pp[L]];
    l[o] = color;
    r[o] = color;
    c[o] = 1;
}
// Return the color at position p. Walk down from node o (range [L, R]) toward p and
// stop at the first node carrying a color tag, or at the leaf.
int gt(int o, int L, int R, int p) {
    while (st[o] == -1 && L != R) {
        const int mid = (L + R) >> 1;
        if (p <= mid) {
            o = o << 1;
            R = mid;
        } else {
            o = o << 1 | 1;
            L = mid + 1;
        }
    }
    return st[o] != -1 ? st[o] : l[o];
}
// Paint positions [l, r] with color p inside node o's range [L, R].
// Note: the bounds l and r shadow the global l[]/r[] arrays inside this function.
void up(int o, int L, int R, int l, int r, int p) {
    if (l <= L && R <= r) {
        // Fully covered: tag the node and refresh its own summary.
        st[o] = p;
        pu(o, L, R);
        return;
    }
    // Hand this node's tag to both children before touching either of them.
    pd(o);
    const int mid = (L + R) >> 1;
    const bool intoLeft = l <= mid;
    const bool intoRight = r > mid;
    // A child we do not descend into still needs its summary rebuilt from the tag
    // it may have just received.
    if (intoLeft) up(o << 1, L, mid, l, r, p);
    else pu(o << 1, L, mid);
    if (intoRight) up(o << 1 | 1, mid + 1, R, l, r, p);
    else pu(o << 1 | 1, mid + 1, R);
    pu(o, L, R);
}
// Count the color segments in positions [ll, rr] inside node o's range [L, R].
// The bounds are named ll/rr so that the global l[]/r[] arrays stay visible here.
int qry(int o, int L, int R, int ll, int rr) {
    if (st[o] != -1) return 1;  // the whole node is a single color
    if (ll <= L && R <= rr) return c[o];
    const int mid = (L + R) >> 1;
    const int lch = o << 1;
    const int rch = o << 1 | 1;
    if (rr <= mid) return qry(lch, L, mid, ll, rr);
    if (ll > mid) return qry(rch, mid + 1, R, ll, rr);
    int total = qry(lch, L, mid, ll, rr) + qry(rch, mid + 1, R, ll, rr);
    // The two halves touch at mid / mid+1; equal colors there form a single segment.
    if (r[lch] == l[rch]) --total;
    return total;
}
void dfs1(int x) {
int mx = 0; sz[x] = 1;
for(int i = hd[x]; i; i = nxt[i]) if(!d[to[i]]) {
d[to[i]] = d[x] + 1, f[to[i]] = x;
dfs1(to[i]), sz[x] += sz[to[i]];
if(sz[to[i]] > mx) mx = sz[to[i]], s[x] = to[i];
}
}
void dfs2(int x) {
if(s[x]) tp[s[x]] = tp[x], p[s[x]] = ++tt, pp[tt] = s[x], dfs2(s[x]);
for(int i = hd[x]; i; i = nxt[i]) if(to[i] != f[x] && to[i] != s[x])
tp[to[i]] = to[i], p[to[i]] = ++tt, pp[tt] = to[i], dfs2(to[i]);
}
// Count the maximal same-color segments on the tree path x-y.
int qr(int x, int y) {
    int segs = 0;
    for (; tp[x] != tp[y]; x = f[tp[x]]) {
        // Compare the depths of the chain tops tp[x]/tp[y], not of x/y themselves.
        if (d[tp[x]] < d[tp[y]]) swap(x, y);
        const int top = tp[x];
        segs += qry(init, p[top], p[x]);
        // The chain top and its parent are adjacent on the path; if they share a
        // color, the segment crossing them was counted twice.
        if (gt(init, p[top]) == gt(init, p[f[top]])) --segs;
    }
    // Same chain now. The last stretch runs from the shallower endpoint to the deeper
    // one (starting at the chain top would be wrong).
    const int shallow = d[x] < d[y] ? x : y;
    const int deep = d[x] < d[y] ? y : x;
    return segs + qry(init, p[shallow], p[deep]);
}
// Recolor every vertex on the tree path x-y with color z, one heavy chain at a time.
void upd(int x, int y, int z) {
    for (; tp[x] != tp[y]; x = f[tp[x]]) {
        // Always lift the endpoint whose chain top is deeper.
        if (d[tp[x]] < d[tp[y]]) swap(x, y);
        up(init, p[tp[x]], p[x], z);
    }
    // Same chain: paint from the shallower endpoint's position to the deeper one's.
    if (d[x] < d[y]) up(init, p[x], p[y], z);
    else up(init, p[y], p[x], z);
}
// Read the tree, decompose it, build the segment tree, then answer the commands:
// "Q x y" prints the segment count on path x-y; any other command reads a color z and
// recolors path x-y.
int main() {
    memset(st, -1, sizeof st);  // -1 = no pending color tag
    scanf("%d%d", &n, &q);
    for (int i = 1; i <= n; ++i) scanf("%d", &v[i]);
    for (int i = 1; i < n; ++i) {
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    // Root the tree at vertex 1. The root's own HLD fields are not set by the DFS
    // passes, so they must be initialized here.
    d[1] = 1;
    dfs1(1);
    tp[1] = 1;
    p[1] = ++tt;
    pp[tt] = 1;
    f[1] = 1;
    dfs2(1);
    bd(init);
    while (q--) {
        scanf("%s%d%d", op, &x, &y);
        if (op[0] != 'Q') {
            scanf("%d", &z);
            upd(x, y, z);
        } else {
            printf("%d\n", qr(x, y));
        }
    }
    return 0;
}