题意:
将路径上的点全部变成c
询问路径上的颜色段数。
树链剖分:
维护一下颜色段数,左端颜色,右端颜色,注意询问的时候要push_down(),还有合并时要注意判断,而不是简单的直接相加。
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#define lc o<<1
#define rc o<<1|1
using namespace std;
const int maxn=100010;
int a[maxn];
vector<int>g[maxn];
int id[maxn],son[maxn],top[maxn],size[maxn],fa[maxn],dep[maxn];
int ql,qr,n,m,cnt;
void add(int x,int y){
g[x].push_back(y);
g[y].push_back(x);
}
void dfs1(int x){
size[x]=1;
for(int i=0;i<g[x].size();i++){
int v=g[x][i];
if(v==fa[x]) continue;
fa[v]=x;
dep[v]=dep[x]+1;
dfs1(v);
size[x]+=size[v];
if(size[v]>size[son[x]]) son[x]=v;
}
}
void dfs2(int x){
id[x]=++cnt;
if(!son[x]) return;
top[son[x]]=top[x];
dfs2(son[x]);
for(int i=0;i<g[x].size();i++){
int v=g[x][i];
if(id[v]) continue;
top[v]=v;
dfs2(v);
}
}
int sum[maxn<<2],cl[maxn<<2],cr[maxn<<2],all[maxn<<2];
void push_down(int o,int l,int r){
if(all[o]){
all[lc]=all[rc]=all[o];
sum[lc]=sum[rc]=1;
cl[lc]=cr[lc]=cl[rc]=cr[rc]=all[o];
all[o]=0;
}
}
void maintain(int o,int l,int r){
int mid=(l+r)>>1;
sum[o]=sum[lc]+sum[rc];
if(cr[lc]==cl[rc]) sum[o]--;
cl[o]=cl[lc];cr[o]=cr[rc];
}
void create_tree(int o,int l,int r,int x,int y){
if(l==r){
sum[o]=1;cl[o]=cr[o]=all[o]=y;
return;
}
push_down(o,l,r);
int mid=(l+r)>>1;
if(x<=mid) create_tree(lc,l,mid,x,y);
else create_tree(rc,mid+1,r,x,y);
maintain(o,l,r);
}
void modify(int o,int l,int r,int val){
if(ql<=l && qr>=r){
all[o]=val;cl[o]=cr[o]=val;sum[o]=1;
return;
}
int mid=(l+r)>>1;
push_down(o,l,r);
if(ql<=mid) modify(lc,l,mid,val);
if(qr>mid) modify(rc,mid+1,r,val);
maintain(o,l,r);
}
int find(int o,int l,int r,int pos){
if(l==r){
return cl[o];
}
push_down(o,l,r);
int mid=(l+r)>>1;
if(pos<=mid) return find(lc,l,mid,pos);
else return find(rc,mid+1,r,pos);
}
int Query(int o,int l,int r){
if(ql<=l && qr>=r){
return sum[o];
}
push_down(o,l,r);
int mid=(l+r)>>1;
int ans=0;
if(ql<=mid) ans+=Query(lc,l,mid);
if(qr>mid) ans+=Query(rc,mid+1,r);
if(ql<=mid && qr>mid && cl[rc]==cr[lc]) ans--;
return ans;
}
void change(int a,int b,int c){
while(top[a]!=top[b]){
if(dep[top[a]]<dep[top[b]]) swap(a,b);
ql=id[top[a]],qr=id[a];
modify(1,1,n,c);
a=fa[top[a]];
}
if(dep[a]<dep[b]) swap(a,b);
ql=id[b],qr=id[a];
modify(1,1,n,c);
}
int query(int a,int b){
int ans=0;
while(top[a]!=top[b]){
if(dep[top[a]]<dep[top[b]]) swap(a,b);
ql=id[top[a]],qr=id[a];
ans+=Query(1,1,n);
if(find(1,1,n,id[top[a]])==find(1,1,n,id[fa[top[a]]])){
ans--;
}
a=fa[top[a]];
}
if(dep[a]<dep[b]) swap(a,b);
ql=id[b],qr=id[a];
ans+=Query(1,1,n);
return ans;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);add(x,y);
}
dep[1]=1;top[1]=1;
dfs1(1);
dfs2(1);
for(int i=1;i<=n;i++){
create_tree(1,1,n,id[i],a[i]);
}
while(m--){
char s[10];
int a,b,c;
scanf("%s",s);
if(s[0]=='C'){
scanf("%d%d%d",&a,&b,&c);
change(a,b,c);
}else{
scanf("%d%d",&a,&b);
printf("%d\n",query(a,b));
}
}
return 0;
}
^_^