树链剖分+线段树
AC code:
#include <cstdio>
#include <vector>
using namespace std;
const int N=100010;
int n,m,tot,cnt1,cnt2;
int col[N],size[N],num[N],bl[N],top[N],last[N];
vector<int> G[N];
struct nod{
int l,r,len,c1,c2,tag;
nod *lc,*rc;
}pool[N<<2];
struct Segtree{
nod *root;
void build(nod **p,int L,int R){
*p=&pool[tot++];
(*p)->l=L;
(*p)->r=R;
(*p)->tag=-1;
if(L==R) return ;
int M=(L+R)>>1;
build(&(*p)->lc,L,M);
build(&(*p)->rc,M+1,R);
}
void clear(nod *p){
if(p->tag==-1) return ;
p->len=1;
p->c1=p->c2=p->tag;
if(p->l!=p->r) p->lc->tag=p->rc->tag=p->tag;
p->tag=-1;
}
void update(nod *p){
p->c1=p->lc->c1;
p->c2=p->rc->c2;
p->len=p->lc->len+p->rc->len;
if(p->lc->c2==p->rc->c1) p->len--;
}
void modify(nod *p,int L,int R,int c){
clear(p);
if(p->l==L&&p->r==R){
p->tag=c;
return ;
}
int M=(p->l+p->r)>>1;
if(R<=M) modify(p->lc,L,R,c);
else if(L>M) modify(p->rc,L,R,c);
else{
modify(p->lc,L,M,c);
modify(p->rc,M+1,R,c);
}
clear(p->lc);
clear(p->rc);
update(p);
}
int qlen(nod *p,int L,int R){
clear(p);
if(p->l==L&&p->r==R) return p->len;
int M=(p->l+p->r)>>1;
if(R<=M) return qlen(p->lc,L,R);
else if(L>M) return qlen(p->rc,L,R);
else{
int t=qlen(p->lc,L,M)+qlen(p->rc,M+1,R);
if(p->lc->c2==p->rc->c1) t--;
return t;
}
}
int qcol(nod *p,int pos){
clear(p);
if(p->l==p->r) return p->c1;
int M=(p->l+p->r)>>1;
if(pos<=M) return qcol(p->lc,pos);
else return qcol(p->rc,pos);
}
}T;
void DFS(int pre,int p){
size[p]=1;
for(int i=0;i<G[p].size();i++){
int q=G[p][i];
if(q==pre) continue;
DFS(p,q);
size[p]+=size[q];
}
}
void cut(int pre,int p,int number,int belong){
int q=0,maxs=0;
num[p]=number;bl[p]=belong;
for(int i=0;i<G[p].size();i++){
int next=G[p][i];
if(next!=pre&&size[next]>maxs){
q=next;
maxs=size[next];
}
}
if(q) cut(p,q,++cnt1,cnt2);
for(int i=0;i<G[p].size();i++){
int next=G[p][i];
if(next!=pre&&next!=q){
last[++cnt2]=p;
top[cnt2]=next;
cut(p,next,++cnt1,cnt2);
}
}
}
void swap(int *x,int *y){
int t=*x;*x=*y;*y=t;
}
void getlen(int u,int v){
int ans=0;
while(bl[u]!=bl[v]){
if(bl[u]>bl[v]) swap(&u,&v);
int tp=num[top[bl[v]]],la=num[last[bl[v]]];
ans+=T.qlen(T.root,tp,num[v]);
if(T.qcol(T.root,tp)==T.qcol(T.root,la)) ans--;
v=last[bl[v]];
}
if(num[u]>num[v]) swap(&u,&v);
ans+=T.qlen(T.root,num[u],num[v]);
printf("%d\n",ans);
}
void dye(int u,int v,int c){
while(bl[u]!=bl[v]){
if(bl[u]>bl[v]) swap(&u,&v);
T.modify(T.root,num[top[bl[v]]],num[v],c);
v=last[bl[v]];
}
if(num[u]>num[v]) swap(&u,&v);
T.modify(T.root,num[u],num[v],c);
}
int main(){
scanf("%d%d",&n,&m);
T.build(&T.root,1,n);
for(int i=1;i<=n;i++) scanf("%d",&col[i]);
for(int i=1;i<n;i++){
int a,b;
scanf("%d%d",&a,&b);
G[a].push_back(b);
G[b].push_back(a);
}
DFS(0,1);
cut(0,1,++cnt1,++cnt2);
for(int i=1;i<=n;i++) T.modify(T.root,num[i],num[i],col[i]);
for(int i=1;i<=m;i++){
char ch;
scanf("\n%c",&ch);
if(ch=='Q'){
int a,b;
scanf("%d%d",&a,&b);
getlen(a,b);
}
else{
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
dye(a,b,c);
}
}
return 0;
}