题目链接:2243: [SDOI2011]染色
一眼就知道是个树剖……
对于线段树上每个区间,我们维护最左边的颜色、最右边的颜色、总颜色段数
合并区间的时候，父区间的颜色段数 = 左右儿子的颜色段数之和 − [左儿子最右边的颜色是否和右儿子最左边的颜色相等]（相等则减 1）
然而我写炸了一上午QAQ
指针的线段树等于号写成了减号真是看不出来QAQ
令人鸡冻的代码:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int maxn=1000000+10;
// Segment-tree node over the half-open position interval [l, r).
//   sum    - number of maximal same-color runs inside the interval
//   cl, cr - color at the leftmost / rightmost position of the interval
//   lc, rc - children (NULL for leaves)
//   data   - pending "paint the whole interval" color; -1 means no tag
//            (assumes real colors are never -1)
// Every member is initialised so a freshly allocated node never carries
// indeterminate values.
struct seg{
    int l,r,sum;
    seg *lc,*rc;
    int cl,cr,data;
    seg():l(0),r(0),sum(0),lc(NULL),rc(NULL),cl(-1),cr(-1),data(-1){}
};
// Root of the segment tree over DFS2 positions; build() fills it in.
seg *root=new seg();
// Vertex count and operation count.
int n,m;
// pos[v]: position of vertex v in DFS2 order (1-based).
int pos[maxn];
// Last position handed out by DFS2.
int ind=0;
// Belong[v]: top vertex of the heavy chain containing v.
int Belong[maxn];
// a[v]: initial color of vertex v.
int a[maxn];
// col[i]: initial color stored at DFS2 position i.
int col[maxn];
// fa[v][i]: 2^i-th ancestor of v, 0 when it does not exist.
int fa[maxn][21];
// dep[v]: depth of v (root has depth 0).
int dep[maxn];
// Subtree sizes. NOTE: with `using namespace std` under C++17 an unqualified
// `size` is ambiguous with std::size; refer to this array as `::size`.
int size[maxn];
// h[v]: index of the first edge in v's adjacency list (0 terminates it).
int h[maxn];
// Last used edge index; the first edge added gets index 2.
int tot=1;
// Edge pool for the adjacency lists (two directed entries per tree edge).
struct edge{int to,next;}G[maxn*4];
// vis[v]: set once DFS1 has entered v.
bool vis[maxn];
// Recompute p's summary (leftmost color, rightmost color, number of color
// runs) from its children. Leaves keep their own values.
void push_up(seg *p){
    if (p->l+1==p->r) return;
    seg *ls=p->lc;
    seg *rs=p->rc;
    // Check for missing children before dereferencing them. build() always
    // creates both children of an internal node, so this is defensive.
    if (ls==NULL&&rs==NULL) return;
    if (ls==NULL){ p->cl=rs->cl; p->cr=rs->cr; p->sum=rs->sum; return; }
    if (rs==NULL){ p->cl=ls->cl; p->cr=ls->cr; p->sum=ls->sum; return; }
    p->cl=ls->cl;
    p->cr=rs->cr;
    // Runs touching at the seam merge into one when their colors match.
    p->sum=ls->sum+rs->sum-(ls->cr==rs->cl?1:0);
}
// Propagate p's pending paint tag to both children, after which each child
// is a single run of that color. Leaves and untagged nodes are left alone.
void push_down(seg *p){
    const int c=p->data;
    if (c==-1) return;          // no pending paint
    if (p->l+1==p->r) return;   // leaves have no children to receive it
    seg *kids[2]={p->lc,p->rc};
    for (int k=0;k<2;++k){
        seg *q=kids[k];
        q->cl=c;
        q->cr=c;
        q->data=c;
        q->sum=1;
    }
    p->data=-1;
}
// Build the subtree rooted at p over positions [l, r). Leaf colors come
// from col[]. An empty range (r <= l) only records the bounds.
void build(seg *p,int l,int r){
    p->l=l;
    p->r=r;
    if (l+1==r){
        p->cl=p->cr=col[l];
        p->sum=1;
        p->lc=NULL;
        p->rc=NULL;
        return;
    }
    if (r<=l) return;
    // Here r-l >= 2, so l < mid < r: both halves are non-empty and both
    // children always exist (no allocate-then-discard branches).
    const int mid=(l+r)>>1;
    p->lc=new seg();
    build(p->lc,l,mid);
    p->rc=new seg();
    build(p->rc,mid,r);
    push_up(p);
}
// Prepend the directed edge x -> y to x's adjacency list.
void add(int x,int y){
    const int id=++tot;
    G[id].to=y;
    G[id].next=h[x];
    h[x]=id;
}
void DFS1(int x,int deep){
dep[x]=deep; size[x]=1; vis[x]=1;
for (int i=1;i<=19;++i){
if (dep[x]<(1<<i)) break;
fa[x][i]=fa[fa[x][i-1]][i-1];
}
for (int i=h[x];i;i=G[i].next){
int v=G[i].to;
if (vis[v]) continue;
fa[v][0]=x;
DFS1(v,deep+1);
size[x]+=size[v];
}
}
void DFS2(int x,int L){
int k=0; ++ind; Belong[x]=L;
pos[x]=ind; col[ind]=a[x];
for (int i=h[x];i;i=G[i].next)
if (dep[x]<dep[G[i].to]&&size[G[i].to]>size[k])
k=G[i].to;
if (!k) return; DFS2(k,L);
for (int i=h[x];i;i=G[i].next)
if (dep[x]<dep[G[i].to]&&k!=G[i].to)
DFS2(G[i].to,G[i].to);
}
int query(int a,int b){
if (dep[a]<dep[b]) swap(a,b);
int t=dep[a]-dep[b];
for (int i=0;i<=19;++i)
if (t&(1<<i)) a=fa[a][i];
for (int i=19;~i;i--)
if (fa[a][i]!=fa[b][i]){
a=fa[a][i];b=fa[b][i];}
if (a==b) return a;
else return fa[a][0];
}
// Number of color runs in positions [l, r) (half-open) within p's interval.
// Returns 0 when [l, r) is empty or does not intersect p's interval, so
// control never reaches the end of the function without a value.
int ask(seg *p,int l,int r){
    if (r<=l||r<=p->l||p->r<=l) return 0;
    if (l<=p->l&&p->r<=r) return p->sum;
    if (p->data!=-1) push_down(p);
    int mid=(p->l+p->r)>>1;
    if (l<mid&&mid<r){
        // The query straddles mid: positions mid-1 and mid are both inside
        // it, so equal colors at the seam form one run, not two.
        int tmp=p->lc->cr==p->rc->cl?-1:0;
        tmp+=ask(p->lc,l,r);
        tmp+=ask(p->rc,l,r);
        return tmp;
    }
    // The range intersects p and is non-empty, so exactly one side remains.
    if (l<mid) return ask(p->lc,l,r);
    return ask(p->rc,l,r);
}
// Color currently stored at position po: walk down to its leaf, flushing
// pending paint tags along the way.
int getcol(seg *p,int po){
    while (p->l+1!=p->r){
        if (p->data!=-1) push_down(p);
        const int mid=(p->l+p->r)>>1;
        p=(po<mid)?p->lc:p->rc;
    }
    return p->cl;
}
int getans(int x,int y){
int sum=0;
while (Belong[x]!=Belong[y]){
sum+=ask(root,pos[Belong[x]],pos[x]+1);
int fat=fa[Belong[x]][0];
if (getcol(root,pos[Belong[x]])
==getcol(root,pos[fat])) sum--;
x=fat;
} sum+=ask(root,pos[y],pos[x]+1);
return sum;
}
// Paint positions [l, r) within p's interval with color colo. Fully covered
// nodes get a lazy tag instead of being descended into.
void change(seg *p,int l,int r,int colo){
    push_down(p);
    const bool covered=(l<=p->l&&p->r<=r);
    if (!covered){
        const int mid=(p->l+p->r)>>1;
        if (l<mid) change(p->lc,l,r,colo);
        if (mid<r) change(p->rc,l,r,colo);
        push_up(p);
        return;
    }
    p->data=colo;
    p->sum=1;
    p->cl=colo;
    p->cr=colo;
}
// Paint every vertex on the vertical path from x up to its ancestor y with
// color z, one heavy chain at a time.
void getchange(int x,int y,int z){
    for (;Belong[x]!=Belong[y];x=fa[Belong[x]][0])
        change(root,pos[Belong[x]],pos[x]+1,z);
    change(root,pos[y],pos[x]+1,z);
}
// Reads the tree and the operations from stdin and answers them:
//   "C x y z" - paint every vertex on the path x..y with color z
//   "Q x y"   - print the number of color runs on the path x..y
// Stops early if input ends or is malformed.
int main(){
    if (scanf("%d%d",&n,&m)!=2) return 0;
    for (int i=1;i<=n;++i) scanf("%d",&a[i]);
    for (int i=1;i<n;++i){
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    DFS1(1,0);
    DFS2(1,1);
    build(root,1,n+1);
    for (int i=1;i<=m;++i){
        char s[10];
        // The width limit keeps the token inside the 10-byte buffer.
        if (scanf("%9s",s)!=1) break;
        int x,y,z;
        if (s[0]=='C'){
            if (scanf("%d%d%d",&x,&y,&z)!=3) break;
            int lca=query(x,y);
            getchange(x,lca,z);
            getchange(y,lca,z);
        }else if (s[0]=='Q'){
            if (scanf("%d%d",&x,&y)!=2) break;
            int lca=query(x,y);
            int tmp=getans(x,lca);
            // The LCA's run is counted by both halves, hence the -1.
            tmp=tmp+getans(y,lca)-1;
            printf("%d\n",tmp);
        }
    }
    return 0;
}