思路很简单,实现细节有点多,做了两个半小时。
学过树链剖分一眼就能看出来。
先树链剖分把图建好,线段树维护链,lg,rs表示每个区间的左右端点颜色,seg表示区间颜色段数,tag是染色标记。id是树链上点的编号,W(大写)是树链上点的编号对应的颜色,w(小写)是树链的线段树上叶子的颜色。
#include <bits/stdc++.h>
#define mid (l+r)/2
#define ls (k<<1)
#define rs (k<<1|1)
using namespace std;
const int N=4e5+10;
int seg[N],lg[N],rg[N],tag[N];
int tot=1,n,m;
int head[N],ver[N],nxt[N];
int a[N];
int son[N],sz[N],top[N],dep[N],fa[N],id[N],w[N],W[N];
int cnt;
void add(int u,int v){
ver[++tot]=v,nxt[tot]=head[u],head[u]=tot;
}
void dfs1(int x,int pre){
sz[x]=1;
fa[x]=pre;
int mxson=-1;
dep[x]=dep[pre]+1;
for(int i=head[x];i;i=nxt[i]){
int y=ver[i];
if(y==pre)continue;
dfs1(y,x);
sz[x]+=sz[y];
if(sz[y]>mxson)mxson=sz[y],son[x]=y;
}
}
void dfs2(int x,int topf){
top[x]=topf;
id[x]=++cnt;
W[id[x]]=a[x];
if(!son[x])return;
dfs2(son[x],topf);
for(int i=head[x];i;i=nxt[i]){
int y=ver[i];
if(y==fa[x]||y==son[x])continue;
dfs2(y,y);
}
}
void push_up(int k){
seg[k]=seg[ls]+seg[rs];
if(rg[ls]==lg[rs])seg[k]--;
lg[k]=lg[ls],rg[k]=rg[rs];
}
void build(int l,int r,int k){
if(l==r){
seg[k]=1;
lg[k]=rg[k]=W[l];
w[k]=W[l];
return ;
}
build(l,mid,ls);
build(mid+1,r,rs);
push_up(k);
}
void push_down(int k){
if(tag[k]){
tag[ls]=tag[rs]=tag[k];
lg[ls]=rg[ls]=lg[rs]=rg[rs]=tag[k];
seg[ls]=seg[rs]=1;
w[ls]=w[rs]=tag[k];
tag[k]=0;
}
}
void update(int a,int b,int c,int l,int r,int k){
if(a<=l&&b>=r){
seg[k]=1;
lg[k]=rg[k]=c;
tag[k]=c;
w[k]=c;
return;
}
push_down(k);
if(a<=mid)update(a,b,c,l,mid,ls);
if(b>mid)update(a,b,c,mid+1,r,rs);
push_up(k);
}
void update(int x,int y,int c){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(id[top[x]],id[x],c,1,n,1);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
update(id[x],id[y],c,1,n,1);
}
int query(int a,int b,int l,int r,int k){
if(a<=l&&b>=r){
return seg[k];
}
push_down(k);
int ans=0;
if(a<=mid)ans+=query(a,b,l,mid,ls);
if(b>mid)ans+=query(a,b,mid+1,r,rs);
if(a<=mid&&b>mid&&rg[ls]==lg[rs]){
ans--;
}
return ans;
}
int ask(int a,int l,int r,int k){
if(l==r){
return w[k];
}
push_down(k);
if(a<=mid)return ask(a,l,mid,ls);
else return ask(a,mid+1,r,rs);
}
int query(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
ans+=query(id[top[x]],id[x],1,n,1);
int t1=ask(id[top[x]],1,n,1),t2=ask(id[fa[top[x]]],1,n,1);
if(t1==t2){
ans--;
}
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
ans+=query(id[x],id[y],1,n,1);
return ans;
}
int main(){
std::ios::sync_with_stdio(0);
cin>>n>>m;
for(int i=1;i<=n;i++)cin>>a[i];
for(int i=1;i<n;i++){
int x,y;
cin>>x>>y;
add(x,y);
add(y,x);
}
dfs1(1,0);
dfs2(1,1);
build(1,n,1);
while(m--){
char c;
int x,y,z;
cin>>c>>x>>y;
if(c=='Q'){
cout<<query(x,y)<<endl;
}
else{
cin>>z;
update(x,y,z);
}
}
}