解题思路:
先链剖。
维护线段树时每个节点维护三个值:lc(左端点颜色),rc(右端点颜色),cnt(区间中颜色段数量)。
注意每次合并区间时(详见代码中update,query,Query函数),若左区间rc等于右区间lc是结果要减1;
修改时打标记即可,注意tag初始要赋值为-1,因为有color为0的情况。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
using namespace std;
int getint()
{
int i=0,f=1;char c;
for(c=getchar();(c<'0'||c>'9')&&c!='-';c=getchar());
if(c=='-')f=-1,c=getchar();
for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
return i*f;
}
const int N=100005;
int n,m,a[N];
int ecnt,first[N],next[N<<1],to[N<<1];
int fa[N],dep[N],son[N],size[N],top[N],pos[N],idx[N];
int tot,cnt[N<<2],lc[N<<2],rc[N<<2],tag[N<<2];
void add(int x,int y)
{
next[++ecnt]=first[x],first[x]=ecnt,to[ecnt]=y;
}
void dfs1(int u)
{
size[u]=1;
for(int e=first[u];e;e=next[e])
{
int v=to[e];
if(v==fa[u])continue;
fa[v]=u;
dep[v]=dep[u]+1;
dfs1(v);
size[u]+=size[v];
if(size[v]>size[son[u]])son[u]=v;
}
}
void dfs2(int u)
{
if(son[u])
{
pos[son[u]]=++tot;
idx[tot]=son[u];
top[son[u]]=top[u];
dfs2(son[u]);
}
for(int e=first[u];e;e=next[e])
{
int v=to[e];
if(v==fa[u]||v==son[u])continue;
pos[v]=++tot,idx[tot]=v,top[v]=v;
dfs2(v);
}
}
void update(int k)
{
lc[k]=lc[k<<1],rc[k]=rc[k<<1|1];
cnt[k]=cnt[k<<1]+cnt[k<<1|1];
if(rc[k<<1]==lc[k<<1|1])cnt[k]--;
}
void build(int k,int l,int r)
{
if(l==r)
{
lc[k]=rc[k]=a[idx[l]];
cnt[k]=1;
return;
}
int mid=l+r>>1;
build(k<<1,l,mid),build(k<<1|1,mid+1,r);
update(k);
}
void pushdown(int k)
{
lc[k<<1]=rc[k<<1]=tag[k<<1]=tag[k];
cnt[k<<1]=1;
lc[k<<1|1]=rc[k<<1|1]=tag[k<<1|1]=tag[k];
cnt[k<<1|1]=1;
tag[k]=-1;
}
void modify(int k,int l,int r,int x,int y,int c)
{
if(x<=l&&r<=y)
{
lc[k]=rc[k]=tag[k]=c;
cnt[k]=1;
return;
}
if(tag[k]!=-1)pushdown(k);
int mid=l+r>>1;
if(y<=mid)modify(k<<1,l,mid,x,y,c);
else if(x>mid)modify(k<<1|1,mid+1,r,x,y,c);
else modify(k<<1,l,mid,x,mid,c),modify(k<<1|1,mid+1,r,mid+1,y,c);
update(k);
}
int find(int k,int l,int r,int p)
{
if(l==r)return lc[k];
if(tag[k]!=-1)pushdown(k);
int mid=l+r>>1;
if(p<=mid)return find(k<<1,l,mid,p);
else return find(k<<1|1,mid+1,r,p);
}
int query(int k,int l,int r,int x,int y)
{
if(x<=l&&r<=y)return cnt[k];
if(tag[k]!=-1)pushdown(k);
int mid=l+r>>1;
if(y<=mid)return query(k<<1,l,mid,x,y);
else if(x>mid)return query(k<<1|1,mid+1,r,x,y);
else if(rc[k<<1]!=lc[k<<1|1])return query(k<<1,l,mid,x,mid)+query(k<<1|1,mid+1,r,mid+1,y);
else return query(k<<1,l,mid,x,mid)+query(k<<1|1,mid+1,r,mid+1,y)-1;
}
void Modify(int u,int v,int c)
{
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]])swap(u,v);
modify(1,1,n,pos[top[u]],pos[u],c);
u=fa[top[u]];
}
if(dep[u]>dep[v])swap(u,v);
modify(1,1,n,pos[u],pos[v],c);
}
int Query(int u,int v)
{
int res=0;
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]])swap(u,v);
res+=query(1,1,n,pos[top[u]],pos[u]);
if(find(1,1,n,pos[fa[top[u]]])==find(1,1,n,pos[top[u]]))res--;
u=fa[top[u]];
}
if(dep[u]>dep[v])swap(u,v);
res+=query(1,1,n,pos[u],pos[v]);
return res;
}
int main()
{
//freopen("lx.in","r",stdin);
//freopen("lx.out","w",stdout);
int x,y,z;char c;
memset(tag,-1,sizeof(tag));
n=getint(),m=getint();
for(int i=1;i<=n;i++)
a[i]=getint();
for(int i=1;i<n;i++)
{
x=getint(),y=getint();
add(x,y),add(y,x);
}
dfs1(1);
tot=top[1]=pos[1]=idx[1]=1;
dfs2(1);
build(1,1,n);
while(m--)
{
for(c=getchar();c!='Q'&&c!='C';c=getchar());
if(c=='Q')
{
x=getint(),y=getint();
cout<<Query(x,y)<<'\n';
}
else
{
x=getint(),y=getint(),z=getint();
Modify(x,y,z);
}
}
return 0;
}