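/*
洛谷 P2486 [SDOI2011] 染色(树链剖分 + 线段树 + 树上区间合并)
题意:
给定一棵 n 个节点的无根树,共有 m 个操作,操作分为两种:
1. 将节点 a 到节点 b 的路径上的所有点(包括 a 和 b)都染成颜色 c。
2. 询问节点 a 到节点 b 的路径上的颜色段数量。
颜色段的定义是:极长的连续相同颜色被认为是一段。例如 112221 由三段组成:11、222、1。
思路:
树链剖分后,路径被拆成若干条重链上的连续区间,用线段树分别求出每个区间的颜色段数再相加。
线段树节点记录区间段数以及左右端点颜色,合并两个子区间时若交界处颜色相同则段数减 1。
树上合并时要考虑上一次跳链的边界:上一条链的链顶与其父节点在路径上相邻,
若两者颜色相同,则它们属于同一段,答案需再减 1。因此查询时记录区间两端颜色(LC/RC),
并与上一次跳链保存的链顶颜色比较。
*/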
#include <bits/stdc++.h>
using namespace std;
// Argument-list shorthands for a segment-tree node's children
// (expect rt, l, r and mid to be in scope at the use site).
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
using ll = long long;
constexpr int INF = 0x3f3f3f3f;
constexpr int maxn = 2e5 + 100;   // capacity of every vertex-indexed array

int n, m;                 // number of vertices / number of operations (read in main)
int k;                    // not referenced in the visible code
int LC, RC;               // set by query(): colors at the left / right end of the queried range
int tot;                  // running DFS timestamp, advanced by dfs2
int a[maxn];              // a[v]: initial color of vertex v
int siz[maxn];            // subtree size
int fa[maxn];             // parent vertex (0 for the root)
int son[maxn];            // heavy child (0 if none)
int deep[maxn];           // depth, root = 1
int id[maxn];             // DFS position of a vertex in heavy-first order
int idd[maxn];            // idd[pos]: initial color of the vertex placed at DFS position pos
int top[maxn];            // head of the heavy chain containing the vertex
int lazy[maxn<<2];        // pending "paint whole range" color per segment-tree node
vector<int> mp[maxn];     // adjacency lists of the tree
// Segment-tree node summarizing DFS positions [l, r] of the heavy-light order.
struct NODE{
    ll ans;   // not referenced by name in the visible code
    ll sum;   // number of maximal same-color runs inside [l, r]
    ll ml;    // color at position l
    ll mr;    // color at position r
    ll l;     // left end of the covered range
    ll r;     // right end of the covered range
}tree[maxn<<2];
// First heavy-light pass over the subtree rooted at x (parent f, depth d).
// Fills fa, deep, siz and son for every vertex reached. son[v] is the first child
// (in adjacency order) with the largest subtree, or 0 for a leaf.
// Iterative on purpose: a recursive version recurses as deep as the tree is tall
// (about n levels on a path graph), which can overflow a small call stack.
void dfs1(int x,int f,int d){
    vector<int> order;          // vertices in pre-order (each parent before its children)
    vector<int> stk(1,x);
    fa[x]=f;
    deep[x]=d;
    while(!stk.empty()){
        int u=stk.back();
        stk.pop_back();
        order.push_back(u);
        siz[u]=1;
        son[u]=0;
        for(int v:mp[u]){
            if(v==fa[u])continue;
            fa[v]=u;
            deep[v]=deep[u]+1;
            stk.push_back(v);
        }
    }
    // Walk the pre-order backwards so every child is finalized before its parent.
    for(auto it=order.rbegin();it!=order.rend();++it){
        int u=*it;
        for(int v:mp[u]){
            if(v==fa[u])continue;
            siz[u]+=siz[v];
            // strict '<' keeps the earliest child among equally large subtrees
            if(siz[son[u]]<siz[v])son[u]=v;
        }
    }
}
// Second heavy-light pass: assigns DFS positions id[] (heavy child first, then light
// children in adjacency order), the chain head top[], and copies each vertex's
// initial color into idd[] at its position. root is the chain head for x.
// Uses an explicit stack instead of recursion (recursion depth can reach n).
// Pushing light children in reverse and the heavy child last reproduces the
// recursive preorder exactly.
void dfs2(int x,int root){
    vector<pair<int,int> > stk;          // (vertex, head of its chain)
    stk.push_back(make_pair(x,root));
    while(!stk.empty()){
        int u=stk.back().first;
        int head=stk.back().second;
        stk.pop_back();
        id[u]=++tot;
        top[u]=head;
        idd[tot]=a[u];
        // Light children each start a new chain; reversed so they pop in adjacency order.
        for(int i=(int)mp[u].size()-1;i>=0;--i){
            int v=mp[u][i];
            if(v!=fa[u]&&v!=son[u])stk.push_back(make_pair(v,v));
        }
        // Heavy child continues the current chain and is visited next.
        if(son[u])stk.push_back(make_pair(son[u],head));
    }
}
void pushup(int rt){
tree[rt].sum=tree[rt<<1].sum+tree[rt<<1|1].sum;
if(tree[rt<<1].mr==tree[rt<<1|1].ml)tree[rt].sum--;
tree[rt].ml=tree[rt<<1].ml,tree[rt].mr=tree[rt<<1|1].mr;
}
void pushdown(int rt){
if(lazy[rt]){
tree[rt<<1].ml=tree[rt<<1|1].mr=lazy[rt];
tree[rt<<1].mr=tree[rt<<1|1].ml=lazy[rt];
tree[rt<<1].sum=tree[rt<<1|1].sum=1;
lazy[rt<<1]=lazy[rt<<1|1]=lazy[rt];
lazy[rt]=0;
}
}
void build(int rt,int l,int r){
tree[rt].l=l,tree[rt].r=r;
if(l==r){
tree[rt].ml=tree[rt].mr=idd[l];
tree[rt].sum=1;
return;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void update(int rt,int x,int y,int c){
if(x<=tree[rt].l&&tree[rt].r<=y){
tree[rt].ml=tree[rt].mr=c;
tree[rt].sum=1;
lazy[rt]=c;
return;
}
pushdown(rt);
int mid=(tree[rt].l+tree[rt].r)>>1;
if(x<=mid)update(rt<<1,x,y,c);
if(y>mid)update(rt<<1|1,x,y,c);
pushup(rt);
}
// Summary (run count, end colors) of DFS positions [x, y] within node rt.
// xx / yy are the original query bounds. Side effect: when the piece that starts at
// xx is reached, its left color goes to LC; when the piece that ends at yy is
// reached, its right color goes to RC.
NODE query(int rt,int x,int y,int xx,int yy){
    if(x<=tree[rt].l&&tree[rt].r<=y){
        if(tree[rt].l==xx)LC=tree[rt].ml;
        if(tree[rt].r==yy)RC=tree[rt].mr;
        return tree[rt];
    }
    pushdown(rt);
    int mid=(tree[rt].l+tree[rt].r)>>1;
    if(y<=mid)return query(rt<<1,x,y,xx,yy);
    if(x>mid)return query(rt<<1|1,x,y,xx,yy);
    NODE lhs=query(rt<<1,x,mid,xx,yy);
    NODE rhs=query(rt<<1|1,mid+1,y,xx,yy);
    // Start from the left half so every field (including l and ans) is defined,
    // instead of returning a node with indeterminate members.
    NODE t=lhs;
    t.r=rhs.r;
    t.mr=rhs.mr;
    t.sum=lhs.sum+rhs.sum-(lhs.mr==rhs.ml?1:0);
    return t;
}
int getdis(int x,int y){//诗
int ans=0,pos1=0,pos2=0;//pos1为x,深度大的点的上端点颜色
while(top[x]!=top[y]){
if(deep[top[x]]<deep[top[y]])swap(x,y),swap(pos1,pos2);
ans+=query(1,id[top[x]],id[x],id[top[x]],id[x]).sum;
if(RC==pos1)ans--;
pos1=LC;
x=fa[top[x]];
}
if(deep[x]>deep[y])swap(x,y),swap(pos1,pos2);
ans+=query(1,id[x],id[y],id[x],id[y]).sum;
if(LC==pos1)ans--;
if(RC==pos2)ans--;
return ans;
}
// Paint every vertex on the tree path x..y (inclusive) with color c.
void upday(int x,int y,int c){
    while(top[x]!=top[y]){
        // Climb from whichever endpoint sits on the chain with the deeper head.
        int &low=(deep[top[x]]>=deep[top[y]])?x:y;
        update(1,id[top[low]],id[low],c);
        low=fa[top[low]];
    }
    // Same chain now: positions are contiguous, shallower vertex has the smaller id.
    update(1,min(id[x],id[y]),max(id[x],id[y]),c);
}
// Reads the tree and initial colors, builds the decomposition, then answers
// "Q a b" (count runs on path a..b) and "C a b c" (paint path a..b with c).
int main(){
    scanf("%d%d",&n,&m);
    for(int v=1;v<=n;++v)scanf("%d",&a[v]);
    for(int e=1;e<n;++e){
        int u,v;
        scanf("%d%d",&u,&v);
        mp[u].push_back(v);
        mp[v].push_back(u);
    }
    // Heavy-light decomposition rooted at vertex 1, then the segment tree over DFS order.
    dfs1(1,0,1);
    dfs2(1,1);
    build(1,1,n);
    for(int q=0;q<m;++q){
        char op[10];
        int u,v;
        scanf("%s%d%d",op,&u,&v);
        if(op[0]=='Q'){
            printf("%d\n",getdis(u,v));
            continue;
        }
        int color;
        scanf("%d",&color);
        upday(u,v,color);
    }
    return 0;
}