这里主要就是每个节点要存中区间里的最左边的颜色l和最右边的颜色r,然后重写线段树上的‘+’,一道基础的树剖,但是一定要仔细!!!
#include<bits/stdc++.h>
#define lson k<<1,l,mid
#define rson k<<1|1,mid+1,r
using namespace std;
const int MX=1e5+9;
struct node{
int to,next;
}edge[MX<<1];
int tot=0,head[MX];
void add(int u,int v){
edge[tot].to=v;
edge[tot].next=head[u];
head[u]=tot++;
}
int n,m,a,b,cnt=0,fa[MX],siz[MX],son[MX],de[MX],colr[MX];
char order[3];
void dfs1(int u,int f,int deth){
fa[u]=f,siz[u]=1,de[u]=deth,son[u]=0;
for( int i=head[u] ; ~i ; i=edge[i].next ){
int v=edge[i].to;
if( v==f )
continue;
dfs1(v,u,deth+1);
siz[u]+=siz[v];
if( son[u]==0 || siz[son[u]]<siz[v] )
son[u]=v;
}
}
int top[MX],st[MX],pos[MX];
void dfs2(int u,int tp){
st[u]=++cnt;
pos[cnt]=u;
top[u]=tp;
if( son[u]>0 )
dfs2(son[u],tp);
for( int i=head[u] ; ~i ; i=edge[i].next ){
int v=edge[i].to;
if( v!=fa[u] && v!=son[u] )
dfs2(v,v);
}
}
struct Node{
int l=-1,r=-1,sum=0,laze=-1;
Node operator +(const Node &a)const{
Node ans;
if( l==-1 )
return a;
if( a.l==-1 ){
ans.l=l,ans.r=r,ans.sum=sum;
return ans;
}
else{
ans.l=l,ans.r=a.r;
ans.sum=sum+a.sum;
if( r==a.l )
ans.sum--;
return ans;
}
}
}t[MX<<2];
void pushup(int k){
t[k].l=t[k<<1].l;
t[k].r=t[k<<1|1].r;
t[k].sum=t[k<<1].sum+t[k<<1|1].sum;
if( t[k<<1].r==t[k<<1|1].l )
t[k].sum--;
return ;
}
void pushdown(int k){
if( t[k].laze!=-1 ){
t[k<<1].laze=t[k<<1|1].laze=t[k].laze;
t[k<<1].l=t[k<<1|1].l=t[k].laze;
t[k<<1].r=t[k<<1|1].r=t[k].laze;
t[k<<1].sum=t[k<<1|1].sum=1;
t[k].laze=-1;
}
}
void build(int k,int l,int r){
t[k].laze=-1;
if( l==r ){
t[k].l=t[k].r=colr[pos[l]];
t[k].sum=1;
return ;
}
int mid=(l+r)>>1;
build(lson);
build(rson);
pushup(k);
}
Node que(int k,int l,int r,int L,int R){
if( L<=l && r<=R )
return t[k];
pushdown(k);
int mid=(l+r)>>1;
Node ans;
if( L<=mid )
ans=que(lson,L,R)+ans;
if( mid<R )
ans=ans+que(rson,L,R);
return ans;
}
void update(int k,int l,int r,int L,int R,int val){
if( L<=l && r<=R ){
t[k].laze=val;
t[k].l=t[k].r=val;
t[k].sum=1;
return ;
}
pushdown(k);
int mid=(l+r)>>1;
if( L<=mid )
update(lson,L,R,val);
if( mid<R )
update(rson,L,R,val);
pushup(k);
}
int vsque(int u,int v){
int lu=0,lv=0,ans=0;
while( top[u]!=top[v] ){
if( de[top[u]]>de[top[v]] ){
Node ansu=que(1,1,n,st[top[u]],st[u]);
if( ansu.r==lu )
ans--;
ans+=ansu.sum;
lu=ansu.l;
u=fa[top[u]];
}
else{
Node ansv=que(1,1,n,st[top[v]],st[v]);
if( ansv.r==lv )
ans--;
ans+=ansv.sum;
lv=ansv.l;
v=fa[top[v]];
}
}
if( de[u]>de[v] ){
swap(u,v);
swap(lu,lv);
}
Node an=que(1,1,n,st[u],st[v]);
ans+=an.sum;
if( an.l==lu )
ans--;
if( an.r==lv )
ans--;
return ans;
}
void vsupdate(int u,int v,int w){
while( top[u]!=top[v] ){
if( de[top[u]]>de[top[v]] )
swap(u,v);
update(1,1,n,st[top[v]],st[v],w);
v=fa[top[v]];
}
if( de[u]>de[v] )
swap(u,v);
update(1,1,n,st[u],st[v],w);
}
int main()
{
// freopen("input.txt","r",stdin);
scanf("%d %d",&n,&m);
memset(head,-1,sizeof(head));
for( int i=1 ; i<=n ; i++ )
scanf("%d",&colr[i]);
for( int i=1 ; i<=n-1 ; i++ ){
scanf("%d %d",&a,&b);
add(a,b),add(b,a);
}
dfs1(1,0,0);
dfs2(1,1);
build(1,1,n);
while( m-- ){
scanf("%s",&order);
int u,v,w;
if( order[0]=='Q' ){
scanf("%d %d",&u,&v);
printf("%d\n",vsque(u,v));
}
else{
scanf("%d %d %d",&u,&v,&w);
vsupdate(u,v,w);
}
}
return 0;
}