题目链接:https://www.luogu.org/problemnew/show/P2486
题目大意:
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),
如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
题目思路:
此题最大的难点在于查询。线段树一共维护三个量,线段上的颜色段数量,左端点颜色,右端点颜色。那么可以发现,当pushup的时候,如果左儿子的右端点如果和右儿子的左端点颜色相同,那么颜色段数量可以-1。最难的对于查询,这里重点讲解。首先树链剖分的本质是将x到y的路径分成很多条链,那么对于本题来说,主要的问题在于每条链之间的衔接。首先跟pushup一样,我们需要得到每条链的端点,来得到是否需要-1,colx coly分别表示链的顶端是什么颜色,由于每次colx和coly都是一起交换,所以我们不需要知道colx到底代表的是哪个点在往上跳,只用知道他跟x绑定起来了就可以了。这样Rc就是我们当前查询的链的最底端,colx就是之前存的之前那条链的最右端,两者比较即可。到最后跳到一条链的时候,需要同时判断两个端点的情况。线段树的查询要注意,如果l<=mid&&r>mid,那么很明显该区间也需要判端点,用上面的方法判断即可。
以下是代码:
#include <bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define per(i,a,b) for(int i=a;i>=b;i--)
#define ll long long
#define inf 0x3f3f3f3f
const int MAXN = 1e5+5;
int n,m,b[MAXN],x,y,z;
vector<int>v[MAXN];
int dfn[MAXN],rk[MAXN],top[MAXN],siz[MAXN],fa[MAXN],son[MAXN],dep[MAXN],tnt;
int Lc,Rc;
char s[MAXN];
struct node{
int l,r,val,lcol,rcol,mark;
}a[MAXN<<2];
void dfs1(int u,int f){
int len=v[u].size();
dep[u]=dep[f]+1;
fa[u]=f;
siz[u]=1;
son[u]=0;
rep(i,0,len-1){
int to=v[u][i];
if(to==f)continue;
dfs1(to,u);
siz[u]+=siz[to];
if(siz[to]>siz[son[u]]){
son[u]=to;
}
}
}
void dfs2(int u,int tp){
int len=v[u].size();
dfn[u]=++tnt,rk[tnt]=u;
top[u]=tp;
if(son[u])dfs2(son[u],tp);
rep(i,0,len-1){
int to=v[u][i];
if(to==fa[u]||to==son[u])continue;
dfs2(to,to);
}
}
void build(int rt,int l,int r){
a[rt].l=l,a[rt].r=r,a[rt].mark=-1;
if(l==r){
a[rt].val=1;
a[rt].lcol=a[rt].rcol=b[rk[l]];
return;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
a[rt].val=a[rt<<1].val+a[rt<<1|1].val;
a[rt].lcol=a[rt<<1].lcol;
a[rt].rcol=a[rt<<1|1].rcol;
if(a[rt<<1].rcol==a[rt<<1|1].lcol)a[rt].val--;
}
void spread(int rt){
if(a[rt].mark!=-1){
a[rt<<1].val=a[rt<<1|1].val=1;
a[rt<<1].lcol=a[rt<<1|1].lcol=a[rt].mark;
a[rt<<1].rcol=a[rt<<1|1].rcol=a[rt].mark;
a[rt<<1].mark=a[rt<<1|1].mark=a[rt].mark;
a[rt].mark=-1;
}
}
void update(int rt,int l,int r,int val){
if(a[rt].l>=l&&a[rt].r<=r){
a[rt].val=1;
a[rt].lcol=a[rt].rcol=val;
a[rt].mark=val;
return ;
}
spread(rt);
int mid=(a[rt].l+a[rt].r)>>1;
if(l<=mid)update(rt<<1,l,r,val);
if(r>mid)update(rt<<1|1,l,r,val);
a[rt].val=a[rt<<1].val+a[rt<<1|1].val;
a[rt].lcol=a[rt<<1].lcol;
a[rt].rcol=a[rt<<1|1].rcol;
if(a[rt<<1].rcol==a[rt<<1|1].lcol)a[rt].val--;
}
int query(int rt,int l,int r){
if(a[rt].l==l)Lc=a[rt].lcol;
if(a[rt].r==r)Rc=a[rt].rcol;
if(a[rt].l>=l&&a[rt].r<=r){
return a[rt].val;
}
spread(rt);
int mid=(a[rt].l+a[rt].r)>>1,ans=0,flag=0;
if(l<=mid)ans+=query(rt<<1,l,r),flag++;
if(r>mid)ans+=query(rt<<1|1,l,r),flag++;
if(flag==2){
if(a[rt<<1].rcol==a[rt<<1|1].lcol)ans--;
}
return ans;
}
void update1(int x,int y,int val){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(1,dfn[top[x]],dfn[x],val);
x=fa[top[x]];
}
if(dfn[x]>dfn[y])swap(x,y);
update(1,dfn[x],dfn[y],val);
}
int query1(int x,int y){
int ans=0,colx=-1,coly=-1;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y),swap(colx,coly);
ans+=query(1,dfn[top[x]],dfn[x]);
if(Rc==colx)ans--;
colx=Lc;
x=fa[top[x]];
}
if(dfn[x]>dfn[y])swap(x,y),swap(colx,coly);
ans+=query(1,dfn[x],dfn[y]);
if(Rc==coly)ans--;
if(Lc==colx)ans--;
return ans;
}
int main(){
while(~scanf("%d%d",&n,&m)){
rep(i,1,n){
scanf("%d",&b[i]);
v[i].clear();
}
rep(i,1,n-1){
scanf("%d%d",&x,&y);
v[x].push_back(y);
v[y].push_back(x);
}
dep[0]=0,tnt=0;
dfs1(1,0);
dfs2(1,1);
build(1,1,n);
rep(i,1,m){
scanf("%s",s);
if(s[0]=='Q'){
scanf("%d%d",&x,&y);
int ans=query1(x,y);
printf("%d\n",ans);
}
else{
scanf("%d%d%d",&x,&y,&z);
update1(x,y,z);
}
}
}
return 0;
}