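// Solution: a path query on a tree calls for heavy-light decomposition (HLD).
// Query: the answer is the sum of the segment counts of the chain pieces that
// make up the path; when jumping from a chain head to its parent, if the head
// and the parent have the same colour, the two pieces form a single segment,
// so subtract 1.
// Segment tree: each node stores l and r, the colours at the leftmost and
// rightmost positions of its range. When merging two children, if the left
// child's right-end colour equals the right child's left-end colour, the two
// touching segments fuse into one, so subtract 1.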
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
// Segment-tree helper macros. `rt` is the current node index and [L, R] the
// range it covers. Expression macros are fully parenthesised so they compose
// safely inside larger expressions (e.g. 2*mid, rs+1).
#define root 1,1,n          // argument list for the whole tree: node 1 covering [1, n]
#define ls (2*rt)           // index of the left child
#define rs (2*rt+1)         // index of the right child
#define mid ((L+R)/2)       // split point of [L, R]
#define lson ls,L,mid       // argument list for the left child: [L, mid]
#define rson rs,mid+1,R     // argument list for the right child: [mid+1, R]
#define inf 0x3f3f3f3f      // conventional "infinity" (was the typo 0x3f3ff3f); currently unused
const int mx = 1e5+5;       // array capacity; vertices are indexed 1..n, so n must be < mx
// One directed edge of the adjacency list (each tree edge is stored twice).
struct edge{
    int v;      // endpoint of the edge
    int next;   // index of the next edge leaving the same vertex; 0 ends the list
}e[mx<<2];
// Segment-tree node summarising a range of positions.
struct node{
    int sum;    // number of maximal same-colour segments in the range
    int l,r;    // colour at the leftmost / rightmost position (-1 = unset)
    int lazy;   // pending colour assignment for the children (-1 = none)
    // Concatenate this range (on the left) with `rhs` (on the right). If the
    // colours touching at the boundary match, those two segments merge into one.
    node operator+(const node rhs)const{
        const bool boundary_joins = (r != -1) && (rhs.l != -1) && (r == rhs.l);
        node merged;
        merged.sum = sum + rhs.sum - (boundary_joins ? 1 : 0);
        merged.l = l;
        merged.r = rhs.r;
        merged.lazy = -1;
        return merged;
    }
}s[mx<<3];
// Per-vertex heavy-light decomposition data (vertices are 1-indexed).
int sz[mx];     // sz[u]: size of the subtree rooted at u (set by dfs)
int son[mx];    // son[u]: heavy child of u, 0 if u has no children
int top[mx];    // top[u]: head of the heavy chain containing u (set by DFS)
int fa[mx];     // fa[u]: parent of u; the root is its own parent (dfs(1,1,1))
int dep[mx];    // dep[u]: depth of u; the root has depth 1
int id[mx];     // id[u]: position of u in the segment tree (DFS order, heavy child first)
int val[mx];    // val[u]: initial colour of vertex u, read from input
int head[mx];   // head[u]: index of u's first edge in e[]; 0 means no edges
int n,q;        // number of vertices and number of operations
int tot,dfn;    // tot: number of edges stored so far; dfn: DFS position counter
// Clear the adjacency lists and counters before reading a new test case.
void init(){
    memset(head, 0, sizeof head);
    dfn = 0;
    tot = 0;
}
// Push the directed edge u -> v onto the front of u's adjacency list.
void add(int u,int v){
    const int idx = ++tot;
    e[idx].next = head[u];
    e[idx].v = v;
    head[u] = idx;
}
// Reset every node covering [L, R] under rt to the empty state:
// no segments, no end colours (-1) and no pending assignment (-1).
void built(int rt,int L,int R){
    s[rt].sum = 0;
    s[rt].lazy = s[rt].l = s[rt].r = -1;
    if(L != R){
        built(lson);
        built(rson);
    }
}
// Forward a pending colour assignment from rt to both children: each child
// becomes a single segment of that colour and inherits the tag.
void push_down(int rt){
    const int tag = s[rt].lazy;
    if(tag == -1)
        return;
    s[ls].l = s[ls].r = s[ls].lazy = tag;
    s[rs].l = s[rs].r = s[rs].lazy = tag;
    s[ls].sum = 1;
    s[rs].sum = 1;
    s[rt].lazy = -1;
}
// Paint positions [l, r] with colour v; rt is a node covering [L, R].
void update(int rt,int L,int R,int l,int r,int v){
    if(l <= L && R <= r){
        // Fully covered: the whole range collapses to one segment of colour v.
        s[rt].sum = 1;
        s[rt].lazy = v;
        s[rt].l = v;
        s[rt].r = v;
        return;
    }
    push_down(rt);
    if(l > mid)
        update(rson,l,r,v);
    else if(r <= mid)
        update(lson,l,r,v);
    else{
        update(lson,l,mid,v);
        update(rson,mid+1,r,v);
    }
    s[rt] = s[ls] + s[rs];
}
// Summary (segment count and end colours) of positions [l, r]; rt covers [L, R].
node query(int rt,int L,int R,int l,int r){
    if(l <= L && R <= r)
        return s[rt];
    push_down(rt);
    if(l > mid)
        return query(rson,l,r);
    if(r <= mid)
        return query(lson,l,r);
    node left_part = query(lson,l,mid);
    return left_part + query(rson,mid+1,r);
}
// First HLD pass: for every vertex under u, record its parent, depth,
// subtree size and heavy child (the first child with the largest subtree).
void dfs(int u,int pre,int de){
    fa[u] = pre;
    dep[u] = de;
    sz[u] = 1;
    son[u] = 0;
    for(int i = head[u]; i != 0; i = e[i].next){
        const int to = e[i].v;
        if(to == pre)
            continue;
        dfs(to, u, de + 1);
        sz[u] += sz[to];
        if(sz[to] > sz[son[u]])
            son[u] = to;
    }
}
// Second HLD pass: number vertices so each heavy chain occupies a contiguous
// range (heavy child visited first) and write each vertex's colour into the
// segment tree. `pre` is the head of the chain that u belongs to.
void DFS(int u,int pre){
    id[u] = ++dfn;
    top[u] = pre;
    update(root, id[u], id[u], val[u]);
    if(son[u] != 0)
        DFS(son[u], pre);
    for(int i = head[u]; i != 0; i = e[i].next){
        const int to = e[i].v;
        if(to == fa[u] || to == son[u])
            continue;
        DFS(to, to);  // a light child starts a new chain
    }
}
int calc(int a,int b){
int ans = 0;
node t,x;
while(top[a]!=top[b]){
if(dep[top[a]] < dep[top[b]])
swap(a,b);
t = query(root,id[top[a]],id[a]);
ans += t.sum;
a = fa[top[a]];
x = query(root,id[a],id[a]);
if(x.l == t.l)
ans--;
}
if(id[a]>id[b]) swap(a,b);
t = query(root,id[a],id[b]);
ans += t.sum;
return ans;
}
// Paint every vertex on the tree path between a and b with colour c.
void change(int a,int b,int c){
    for(; top[a] != top[b]; a = fa[top[a]]){
        if(dep[top[a]] < dep[top[b]])
            swap(a, b);
        // Paint the whole chain piece [top[a], a], then hop above its head.
        update(root, id[top[a]], id[a], c);
    }
    if(id[b] < id[a])
        swap(a, b);
    update(root, id[a], id[b], c);
}
// Driver: processes test cases until the input runs out. Each case is
// "n q", then n initial vertex colours, then n-1 undirected edges, then q
// operations. An operation starting with 'Q' reads "a b" and prints the
// number of colour segments on path a-b; any other operation reads "a b c"
// and paints path a-b with colour c.
int main(){
    // Require both numbers: a partial read (0 or 1) would otherwise keep
    // processing garbage instead of stopping.
    while(scanf("%d%d",&n,&q)==2){
        init();
        for(int i = 1; i <= n; i++)
            scanf("%d",&val[i]);
        for(int i = 2; i <= n; i++){
            int u,v;
            scanf("%d%d",&u,&v);
            add(u,v);
            add(v,u);
        }
        built(root);
        dfs(1,1,1);
        DFS(1,1);
        while(q--){
            char ca[10];
            int a,b,c;
            // Width limit keeps an over-long token from overflowing ca.
            if(scanf("%9s",ca)!=1)
                break;
            if(ca[0]=='Q'){
                if(scanf("%d%d",&a,&b)!=2)
                    break;
                printf("%d\n",calc(a,b));
            }
            else{
                if(scanf("%d%d%d",&a,&b,&c)!=3)
                    break;
                change(a,b,c);
            }
        }
    }
    return 0;
}