题面:BZOJ2243 Luogu2486
线段树维护树上DFS序带修改求区间连续段数目
上面这么长一串可以简单缩成:
树剖维护区间连续段
也可以说成线段树上树(雾
首先外面的那层树剖应该很好套,内层区间线段树搞一搞就好了
线段树维护区间连续段嘛。。。应该来说还是比较好搞吧(虽然搞了我一个下午QAQ)
每个线段树上节点维护t(区间段数),lc(左端颜色),rc(右端颜色)
然后合并时从下传上如果左边段的右端颜色和右边段左端颜色相同,全段t-1
这同时也适用与树链合并答案时候的操作
我们先找出该条重链的top节点颜色和其父亲颜色,再按照上面规律合并答案即可
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<iostream>
#include<cstdlib>
#include<string>
#include<ctime>
#include<queue>
#include<climits>
using namespace std;
int n,nedge=0,v[100001],p[200001],nex[200001],head[200001];
int fa[100001],deep[100001],s[100001],son[100001],top[100001];
int sx[100001],xs[100001],ne=0;
int lt[400001],rt[400001],t[400001],add[400001],ls[400001],rs[400001];
inline void addedge(int a,int b){p[++nedge]=b;nex[nedge]=head[a];head[a]=nedge;}
inline void dfs(int x,int Fa,int dep){
fa[x]=Fa;deep[x]=dep;s[x]=1;
for(int k=head[x];k;k=nex[k]){
if(p[k]==Fa)continue;
dfs(p[k],x,dep+1);s[x]+=s[p[k]];
if(!son[x]||s[p[k]]>s[son[x]])son[x]=p[k];
}
}
inline void dfss(int x,int bh){
top[x]=bh;sx[x]=++ne;xs[sx[x]]=x;
if(!son[x])return;
dfss(son[x],bh);
for(int k=head[x];k;k=nex[k])if(p[k]!=son[x]&&p[k]!=fa[x])dfss(p[k],p[k]);
}
inline void clean(int nod){
if(add[nod]==-1)return;
if(lt[nod]!=rt[nod]){
t[nod*2]=1;add[nod*2]=ls[nod*2]=rs[nod*2]=add[nod];
t[nod*2+1]=1;add[nod*2+1]=ls[nod*2+1]=rs[nod*2+1]=add[nod];
}
add[nod]=-1;
}
inline void build(int l,int r,int nod){
lt[nod]=l;rt[nod]=r;add[nod]=-1;
if(l==r){t[nod]=1;ls[nod]=rs[nod]=v[xs[l]];return;}
int mid=l+r>>1;
build(l,mid,nod*2);build(mid+1,r,nod*2+1);
t[nod]=t[nod*2]+t[nod*2+1];ls[nod]=ls[nod*2];rs[nod]=rs[nod*2+1];
if(rs[nod*2]==ls[nod*2+1])t[nod]--;
}
inline void xg(int i,int j,int nod,int w){
clean(nod);
if(lt[nod]>=i&&rt[nod]<=j){
t[nod]=1;add[nod]=ls[nod]=rs[nod]=w;
return;
}
int mid=lt[nod]+rt[nod]>>1;
if(i<=mid)xg(i,j,nod*2,w);
if(j>mid)xg(i,j,nod*2+1,w);
t[nod]=t[nod*2]+t[nod*2+1];ls[nod]=ls[nod*2];rs[nod]=rs[nod*2+1];
if(rs[nod*2]==ls[nod*2+1])t[nod]--;
}
inline int ssum(int i,int j,int nod){
clean(nod);
if(lt[nod]>=i&&rt[nod]<=j)return t[nod];
int mid=lt[nod]+rt[nod]>>1,ans=0;
if(i<=mid)ans+=ssum(i,j,nod*2);
if(j>mid)ans+=ssum(i,j,nod*2+1);
if(i<=mid&&j>mid&&rs[nod*2]==ls[nod*2+1])ans--;
return ans;
}
inline int col(int x,int nod){
clean(nod);
if(lt[nod]==rt[nod])return ls[nod];
int mid=lt[nod]+rt[nod]>>1;
if(x<=mid)return col(x,nod*2);
else return col(x,nod*2+1);
}
inline int fsum(int x,int y){
int fx=top[x],fy=top[y],ans=0;
while(fx!=fy){
if(deep[fx]<deep[fy])swap(fx,fy),swap(x,y);
ans+=ssum(sx[fx],sx[x],1);
int cx=col(sx[fx],1),cy=col(sx[fa[fx]],1);
if(cx==cy)ans--;
x=fa[fx];fx=top[x];
}
if(deep[x]>deep[y])swap(x,y);
ans+=ssum(sx[x],sx[y],1);
return ans==0?1:ans;
}
inline void fxg(int x,int y,int z){
int fx=top[x],fy=top[y];
while(fx!=fy){
if(deep[fx]<deep[fy])swap(fx,fy),swap(x,y);
xg(sx[fx],sx[x],1,z);
x=fa[fx];fx=top[x];
}
if(deep[x]>deep[y])swap(x,y);
xg(sx[x],sx[y],1,z);
}
int main()
{
int m;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)scanf("%d",&v[i]);
for(int i=1;i<n;i++){
int x,y;scanf("%d%d",&x,&y);
addedge(x,y);addedge(y,x);
}
dfs(1,0,1);dfss(1,1);build(1,n,1);
for(int i=1;i<=m;i++){
char c[5];int x,y,z;scanf("%s%d%d",c+1,&x,&y);
if(c[1]=='C'){
scanf("%d",&z);fxg(x,y,z);
}else printf("%d\n",fsum(x,y));
}
return 0;
}