树链上区间合并的问题比区间修改要复杂,因为每一条重链在线段树上分布一般都是不连续的,所以在进行链上操作时要手动将其合并起来,维护两个端点值
处理时的方向问题:lca->u是一个方向,lca->v是另一个方向,到最后合并这两个放向时都看左端点即可
#include<cstring> #include<string> #include<iostream> #include<queue> #include<cstdio> #include<algorithm> #include<map> #include<cstdlib> #include<cmath> #include<vector> //#pragma comment(linker, "/STACK:1024000000,1024000000"); using namespace std; #define INF 0x3f3f3f3f #define maxn 200005 int v[maxn],fir[maxn],nex[maxn],e_max; int son[maxn],fa[maxn],color[maxn],col[maxn],pos[maxn],siz[maxn],deep[maxn],top[maxn]; int tot; int n,m; struct node { int l,r; int lc,rc; int num,tag; node() { lc=rc=-1; num=0; } } t[4*maxn]; void init_() { memset(son,-1,sizeof son); memset(siz,0,sizeof siz); memset(fir,-1,sizeof fir); e_max=0; tot=1; } void add_edge(int s,int t) { int e=e_max++; v[e]=t; nex[e]=fir[s]; fir[s]=e; } void dfs1(int k,int pre,int d) { deep[k]=d; siz[k]++; fa[k]=pre; for(int i=fir[k]; ~i; i=nex[i]) { int e=v[i]; if(e!=pre) { dfs1(e,k,d+1); siz[k]+=siz[e]; if(son[k]==-1||siz[son[k]]<siz[e]) son[k]=e; } } } void dfs2(int k,int sp) { top[k]=sp; pos[k]=tot++; col[pos[k]]=color[k]; if(son[k]==-1) return ; dfs2(son[k],sp); for(int i=fir[k]; ~i; i=nex[i]) { int e=v[i]; if(e!=fa[k]&&e!=son[k]) { dfs2(e,e); } } } node Merge(node A,node B) { if(A.num==0) return B; if(B.num==0) return A; node temp; temp.lc=A.lc; temp.rc=B.rc; if(A.rc==B.lc) temp.num=A.num+B.num-1; else temp.num=A.num+B.num; return temp; } inline void pushdown(int k) { if(t[k].tag==-1) return ; t[k<<1].tag=t[k<<1|1].tag=t[k].tag; t[k<<1].num=t[k<<1|1].num=1; t[k<<1].lc=t[k<<1].rc=t[k<<1|1].lc=t[k<<1|1].rc=t[k].tag; t[k].tag=-1; } inline void pushup(int k) { t[k].lc=t[k<<1].lc; t[k].rc=t[k<<1|1].rc; t[k].num=t[k<<1].num+t[k<<1|1].num; if(t[k<<1].rc==t[k<<1|1].lc) t[k].num--; } void init(int l,int r,int k) { t[k].l=l; t[k].r=r; t[k].tag=-1; if(l==r) { t[k].num=1; t[k].lc=col[l]; t[k].rc=col[r]; return ; } int mid=l+r>>1; init(l,mid,k<<1); init(mid+1,r,k<<1|1); pushup(k); } void update(int d,int l,int r,int k) { if(t[k].l==l&&t[k].r==r) { t[k].num=1; t[k].tag=d; t[k].lc=t[k].rc=d; return ; } pushdown(k); int mid=t[k].l+t[k].r>>1; if(r<=mid) update(d,l,r,k<<1); else if(l>mid) update(d,l,r,k<<1|1); else { update(d,l,mid,k<<1); update(d,mid+1,r,k<<1|1); } pushup(k); } node query(int l,int r,int k) { if(t[k].l==l&&t[k].r==r) { return t[k]; } pushdown(k); int mid=t[k].l+t[k].r>>1; if(r<=mid) return query(l,r,k<<1); else if(l>mid) return query(l,r,k<<1|1); else return Merge(query(l,mid,k<<1),query(mid+1,r,k<<1|1)); } void Query(int s,int t) { node L,R; int f1=top[s],f2=top[t]; while(f1!=f2) { if(deep[f1]<deep[f2]) swap(L,R),swap(f1,f2),swap(s,t); L=Merge(query(pos[f1],pos[s],1),L); s=fa[f1]; f1=top[s]; } if(deep[s]<deep[t]) swap(L,R),swap(s,t); L=Merge(query(pos[t],pos[s],1),L); if(L.lc==R.lc) L.num=L.num+R.num-1; else L.num=L.num+R.num; printf("%d\n",L.num); } void Change(int s,int t,int c) { int f1=top[s],f2=top[t]; while(f1!=f2) { if(deep[f1]<deep[f2]) swap(f1,f2),swap(s,t); update(c,pos[f1],pos[s],1); s=fa[f1]; f1=top[s]; } if(deep[s]>deep[t]) swap(s,t); update(c,pos[s],pos[t],1); } int main() { while(scanf("%d%d",&n,&m)!=EOF) { init_(); for(int i=1; i<=n; i++) { scanf("%d",&color[i]); } for(int i=1; i<n; i++) { int a,b; scanf("%d%d",&a,&b); add_edge(a,b); add_edge(b,a); } dfs1(1,-1,1); dfs2(1,1); init(1,tot-1,1); while(m--) { char s[5]; scanf("%s",s); if(s[0]=='Q') { int l,r; scanf("%d%d",&l,&r); Query(l,r); } else { int a,b,c; scanf("%d%d%d",&a,&b,&c); Change(a,b,c); } } } return 0; }