染色-树链剖分
题目描述
解法
先树剖,然后再线段树维护
考虑对每个节点维护
n
u
m
num
num值:表示当前区间颜色段数量
l
c
lc
lc值:表示当前区间最左端的颜色
r
c
rc
rc值:表示当前区间最右端的颜色
合并时,如果左儿子的
r
c
=
=
rc==
rc==右儿子的
l
c
lc
lc,那么合并后父节点的
n
u
m
num
num值还要减一;
查询时同样要注意,参见代码
代码实现
#include<bits/stdc++.h>
#define M 200009
using namespace std;
int nxt[M],to[M],first[M],tot,n,m,cnt,LC,RC;
int idx[M],num[M],top[M],dep[M],f[M],a[M],size[M],son[M];
struct tree{
int num,l,r,lc,rc,add;
}tr[M*4];
int read(){
int f=1,re=0;char ch;
for(ch=getchar();!isdigit(ch)&&ch!='-';ch=getchar());
if(ch=='-'){f=-1;ch=getchar();}
for(;isdigit(ch);ch=getchar()) re=(re<<3)+(re<<1)+ch-'0';
return re*f;
}
void add(int x,int y){
nxt[++tot]=first[x];
first[x]=tot;
to[tot]=y;
}
void dfs1(int u,int fa){
dep[u]=dep[fa]+1,f[u]=fa,size[u]=1;
for(int i=first[u];i;i=nxt[i]){
int v=to[i];
if(v==fa) continue;
dfs1(v,u);
size[u]+=size[v];
if(size[v]>size[son[u]]) son[u]=v;
}
}
void dfs2(int u,int tp){
num[u]=++cnt,idx[cnt]=u,top[u]=tp;
if(son[u]) dfs2(son[u],tp);
for(int i=first[u];i;i=nxt[i])
if(!num[to[i]]) dfs2(to[i],to[i]);
}
void pushup(int k){
tr[k].lc=tr[k<<1].lc;
tr[k].rc=tr[k<<1|1].rc;
tr[k].num=tr[k<<1].num+tr[k<<1|1].num-(tr[k<<1].rc==tr[k<<1|1].lc);
}
void build(int k,int l,int r){
tr[k].l=l,tr[k].r=r;
if(l==r){
tr[k].lc=tr[k].rc=a[idx[l]];
tr[k].num=1;
return;
}int mid=(l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
pushup(k);
}
void change(int k,int val){
tr[k].num=1;
tr[k].lc=tr[k].rc=val;
tr[k].add=val;
}
void pushdown(int k){
if(tr[k].add){
change(k<<1,tr[k].add);
change(k<<1|1,tr[k].add);
tr[k].add=0;
}
}
void modify(int k,int l,int r,int val){
if(tr[k].l>=l&&tr[k].r<=r) return change(k,val);
pushdown(k);
int mid=(tr[k].l+tr[k].r)>>1;
if(l<=mid) modify(k<<1,l,r,val);
if(r>mid) modify(k<<1|1,l,r,val);
pushup(k);
}
int query(int k,int l,int r){
if(tr[k].l==l) LC=tr[k].lc;
if(tr[k].r==r) RC=tr[k].rc;
if(tr[k].l>=l&&tr[k].r<=r) return tr[k].num;
pushdown(k);
int mid=(tr[k].l+tr[k].r)>>1;
if(l>mid) return query(k<<1|1,l,r);
if(r<=mid) return query(k<<1,l,r);
if(l<=mid||r>mid) return query(k<<1,l,r)+query(k<<1|1,l,r)-(tr[k<<1].rc==tr[k<<1|1].lc);
}
void update(int x,int y,int z){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
modify(1,num[top[x]],num[x],z);
x=f[top[x]];
}if(dep[x]<dep[y]) swap(x,y);
modify(1,num[y],num[x],z);
}
int solve(int x,int y){
int ans=0,vis1=0,vis2=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y),swap(vis1,vis2);
ans+=query(1,num[top[x]],num[x]);
if(vis1==RC) ans--;
vis1=LC;
x=f[top[x]];
}if(dep[x]<dep[y]) swap(x,y),swap(vis1,vis2);
ans+=query(1,num[y],num[x]);
if(vis1==RC) ans--;
if(vis2==LC) ans--;
return ans;
}
//void debug(int k){
// printf("%d %d %d %d %d\n",tr[k].l,tr[k].r,tr[k].lc,tr[k].rc,tr[k].num);
// if(tr[k].l==tr[k].r) return;
// debug(k<<1);debug(k<<1|1);
//}
int main(){
int x,y,z;char s;
n=read(),m=read();
for(int i=1;i<=n;i++) a[i]=read();
for(int i=1;i<n;i++){
x=read(),y=read();
add(x,y),add(y,x);
}dfs1(1,0),dfs2(1,1);
build(1,1,n);
for(int i=1;i<=m;i++){
cin>>s;
x=read(),y=read();
if(s=='C'){
z=read();
update(x,y,z);
}else printf("%d\n",solve(x,y));
}return 0;
}
做题启发
1,对于该类区间合并时,左右区间相互影响时,考虑本题分多种情况的做法